Add getdotattr

This commit is contained in:
Konsta Vesterinen
2014-02-18 11:36:19 +02:00
parent 1b344ff354
commit 09c2d7ddc7
3 changed files with 107 additions and 18 deletions

View File

@@ -12,37 +12,39 @@ from .database import (
non_indexed_foreign_keys, non_indexed_foreign_keys,
) )
from .orm import ( from .orm import (
primary_keys,
table_name,
declarative_base, declarative_base,
getdotattr,
has_changes, has_changes,
identity, identity,
naturally_equivalent, naturally_equivalent,
remove_property primary_keys,
remove_property,
table_name,
) )
__all__ = ( __all__ = (
create_mock_engine,
defer_except,
mock_engine,
sort_query,
render_expression,
render_statement,
QuerySorterException,
database_exists,
create_database, create_database,
create_mock_engine,
database_exists,
declarative_base,
defer_except,
drop_database, drop_database,
escape_like, escape_like,
is_auto_assigned_date_column, getdotattr,
is_indexed_foreign_key,
non_indexed_foreign_keys,
remove_property,
primary_keys,
table_name,
declarative_base,
has_changes, has_changes,
identity, identity,
is_auto_assigned_date_column,
is_indexed_foreign_key,
mock_engine,
naturally_equivalent, naturally_equivalent,
non_indexed_foreign_keys,
primary_keys,
QuerySorterException,
remove_property,
render_expression,
render_statement,
sort_query,
table_name,
) )

View File

@@ -1,7 +1,9 @@
from functools import partial from functools import partial
from toolz import curry, first from toolz import curry, first
import six
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import inspect from sqlalchemy import inspect
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.query import _ColumnEntity from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.util import AliasedInsp from sqlalchemy.orm.util import AliasedInsp
@@ -218,6 +220,32 @@ def declarative_base(model):
return model return model
def getdotattr(obj_or_class, dot_path):
"""
Allow dot-notated strings to be passed to `getattr`.
::
getdotattr(SubSection, 'section.document')
getdotattr(subsection, 'section.document')
:param obj_or_class: Any object or class
:param dot_path: Attribute path with dot mark as separator
"""
def get_attr(mixed, attr):
if isinstance(mixed, InstrumentedAttribute):
return getattr(
mixed.property.mapper.class_,
attr
)
else:
return getattr(mixed, attr)
return six.moves.reduce(get_attr, dot_path.split('.'), obj_or_class)
def has_changes(obj, attr): def has_changes(obj, attr):
""" """
Simple shortcut function for checking if given attribute of given Simple shortcut function for checking if given attribute of given

View File

@@ -0,0 +1,59 @@
import sqlalchemy as sa
from sqlalchemy_utils.functions import getdotattr
from tests import TestCase
class TestTwoWayAttributeValueGeneration(TestCase):
def create_models(self):
class Document(self.Base):
__tablename__ = 'document'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
class Section(self.Base):
__tablename__ = 'section'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
document_id = sa.Column(
sa.Integer, sa.ForeignKey(Document.id)
)
document = sa.orm.relationship(Document)
class SubSection(self.Base):
__tablename__ = 'subsection'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
locale = sa.Column(sa.String(10))
section_id = sa.Column(
sa.Integer, sa.ForeignKey(Section.id)
)
section = sa.orm.relationship(Section)
self.Document = Document
self.Section = Section
self.SubSection = SubSection
def test_getdotattr_for_objects(self):
document = self.Document(name=u'some document')
section = self.Section(document=document)
subsection = self.SubSection(section=section)
assert getdotattr(
subsection,
'section.document.name'
) == u'some document'
def test_getdotattr_for_class_paths(self):
assert getdotattr(self.Section, 'document') is self.Section.document
assert (
getdotattr(self.SubSection, 'section.document') is
self.Section.document
)
assert getdotattr(self.Section, 'document.name') is self.Document.name