Add condition parameter for getdotattr

This commit is contained in:
Konsta Vesterinen
2014-12-12 10:11:19 +02:00
parent 54a713959f
commit 2c22e69edc
3 changed files with 33 additions and 7 deletions

View File

@@ -3,7 +3,6 @@ import itertools
import sqlalchemy as sa
import six
from .functions import getdotattr
from .path import AttrPath
class AttributeValueGenerator(object):

View File

@@ -16,6 +16,7 @@ from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.session import object_session
from sqlalchemy.orm.util import AliasedInsp
from sqlalchemy_utils.utils import is_sequence
def get_column_key(model, column):
@@ -623,7 +624,7 @@ def get_declarative_base(model):
return model
def getdotattr(obj_or_class, dot_path):
def getdotattr(obj_or_class, dot_path, condition=None):
"""
Allow dot-notated strings to be passed to `getattr`.
@@ -638,22 +639,39 @@ def getdotattr(obj_or_class, dot_path):
:param dot_path: Attribute path with dot mark as separator
"""
last = obj_or_class
# Coerce object style paths to strings.
path = str(dot_path)
for path in dot_path.split('.'):
for path in str(dot_path).split('.'):
getter = attrgetter(path)
if isinstance(last, list):
last = sum((getter(el) for el in last), [])
if is_sequence(last):
tmp = []
for element in last:
value = getter(element)
if is_sequence(value):
tmp.extend(value)
else:
tmp.append(value)
last = tmp
elif isinstance(last, InstrumentedAttribute):
last = getter(last.property.mapper.class_)
elif last is None:
return None
else:
last = getter(last)
if condition is not None:
if is_sequence(last):
last = [v for v in last if condition(v)]
else:
if not condition(last):
return None
return last
def is_deleted(obj):
return obj in sa.orm.object_session(obj).deleted
def has_changes(obj, attrs=None, exclude=None):
"""
Simple shortcut function for checking if given attributes of given

View File

@@ -1,4 +1,7 @@
import sys
from collections import Iterable
import six
def str_coercible(cls):
@@ -11,3 +14,9 @@ def str_coercible(cls):
cls.__str__ = __str__
return cls
def is_sequence(value):
return (
isinstance(value, Iterable) and not isinstance(value, six.string_types)
)