diff --git a/sqlalchemy_utils/decorators.py b/sqlalchemy_utils/decorators.py index af957b5..9a8a071 100644 --- a/sqlalchemy_utils/decorators.py +++ b/sqlalchemy_utils/decorators.py @@ -3,7 +3,6 @@ import itertools import sqlalchemy as sa import six from .functions import getdotattr -from .path import AttrPath class AttributeValueGenerator(object): diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 86ba278..12cb752 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -16,6 +16,7 @@ from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.query import _ColumnEntity from sqlalchemy.orm.session import object_session from sqlalchemy.orm.util import AliasedInsp +from sqlalchemy_utils.utils import is_sequence def get_column_key(model, column): @@ -623,7 +624,7 @@ def get_declarative_base(model): return model -def getdotattr(obj_or_class, dot_path): +def getdotattr(obj_or_class, dot_path, condition=None): """ Allow dot-notated strings to be passed to `getattr`. @@ -638,22 +639,39 @@ def getdotattr(obj_or_class, dot_path): :param dot_path: Attribute path with dot mark as separator """ last = obj_or_class - # Coerce object style paths to strings. - path = str(dot_path) - for path in dot_path.split('.'): + for path in str(dot_path).split('.'): getter = attrgetter(path) - if isinstance(last, list): - last = sum((getter(el) for el in last), []) + + if is_sequence(last): + tmp = [] + for element in last: + value = getter(element) + if is_sequence(value): + tmp.extend(value) + else: + tmp.append(value) + last = tmp elif isinstance(last, InstrumentedAttribute): last = getter(last.property.mapper.class_) elif last is None: return None else: last = getter(last) + if condition is not None: + if is_sequence(last): + last = [v for v in last if condition(v)] + else: + if not condition(last): + return None + return last +def is_deleted(obj): + return obj in sa.orm.object_session(obj).deleted + + def has_changes(obj, attrs=None, exclude=None): """ Simple shortcut function for checking if given attributes of given diff --git a/sqlalchemy_utils/utils.py b/sqlalchemy_utils/utils.py index 4efed7b..8c5b6b4 100644 --- a/sqlalchemy_utils/utils.py +++ b/sqlalchemy_utils/utils.py @@ -1,4 +1,7 @@ import sys +from collections import Iterable + +import six def str_coercible(cls): @@ -11,3 +14,9 @@ def str_coercible(cls): cls.__str__ = __str__ return cls + + +def is_sequence(value): + return ( + isinstance(value, Iterable) and not isinstance(value, six.string_types) + )