Add condition parameter for getdotattr
This commit is contained in:
@@ -3,7 +3,6 @@ import itertools
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
import six
|
import six
|
||||||
from .functions import getdotattr
|
from .functions import getdotattr
|
||||||
from .path import AttrPath
|
|
||||||
|
|
||||||
|
|
||||||
class AttributeValueGenerator(object):
|
class AttributeValueGenerator(object):
|
||||||
|
@@ -16,6 +16,7 @@ from sqlalchemy.orm.properties import ColumnProperty
|
|||||||
from sqlalchemy.orm.query import _ColumnEntity
|
from sqlalchemy.orm.query import _ColumnEntity
|
||||||
from sqlalchemy.orm.session import object_session
|
from sqlalchemy.orm.session import object_session
|
||||||
from sqlalchemy.orm.util import AliasedInsp
|
from sqlalchemy.orm.util import AliasedInsp
|
||||||
|
from sqlalchemy_utils.utils import is_sequence
|
||||||
|
|
||||||
|
|
||||||
def get_column_key(model, column):
|
def get_column_key(model, column):
|
||||||
@@ -623,7 +624,7 @@ def get_declarative_base(model):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def getdotattr(obj_or_class, dot_path):
|
def getdotattr(obj_or_class, dot_path, condition=None):
|
||||||
"""
|
"""
|
||||||
Allow dot-notated strings to be passed to `getattr`.
|
Allow dot-notated strings to be passed to `getattr`.
|
||||||
|
|
||||||
@@ -638,22 +639,39 @@ def getdotattr(obj_or_class, dot_path):
|
|||||||
:param dot_path: Attribute path with dot mark as separator
|
:param dot_path: Attribute path with dot mark as separator
|
||||||
"""
|
"""
|
||||||
last = obj_or_class
|
last = obj_or_class
|
||||||
# Coerce object style paths to strings.
|
|
||||||
path = str(dot_path)
|
|
||||||
|
|
||||||
for path in dot_path.split('.'):
|
for path in str(dot_path).split('.'):
|
||||||
getter = attrgetter(path)
|
getter = attrgetter(path)
|
||||||
if isinstance(last, list):
|
|
||||||
last = sum((getter(el) for el in last), [])
|
if is_sequence(last):
|
||||||
|
tmp = []
|
||||||
|
for element in last:
|
||||||
|
value = getter(element)
|
||||||
|
if is_sequence(value):
|
||||||
|
tmp.extend(value)
|
||||||
|
else:
|
||||||
|
tmp.append(value)
|
||||||
|
last = tmp
|
||||||
elif isinstance(last, InstrumentedAttribute):
|
elif isinstance(last, InstrumentedAttribute):
|
||||||
last = getter(last.property.mapper.class_)
|
last = getter(last.property.mapper.class_)
|
||||||
elif last is None:
|
elif last is None:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
last = getter(last)
|
last = getter(last)
|
||||||
|
if condition is not None:
|
||||||
|
if is_sequence(last):
|
||||||
|
last = [v for v in last if condition(v)]
|
||||||
|
else:
|
||||||
|
if not condition(last):
|
||||||
|
return None
|
||||||
|
|
||||||
return last
|
return last
|
||||||
|
|
||||||
|
|
||||||
|
def is_deleted(obj):
|
||||||
|
return obj in sa.orm.object_session(obj).deleted
|
||||||
|
|
||||||
|
|
||||||
def has_changes(obj, attrs=None, exclude=None):
|
def has_changes(obj, attrs=None, exclude=None):
|
||||||
"""
|
"""
|
||||||
Simple shortcut function for checking if given attributes of given
|
Simple shortcut function for checking if given attributes of given
|
||||||
|
@@ -1,4 +1,7 @@
|
|||||||
import sys
|
import sys
|
||||||
|
from collections import Iterable
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
|
|
||||||
def str_coercible(cls):
|
def str_coercible(cls):
|
||||||
@@ -11,3 +14,9 @@ def str_coercible(cls):
|
|||||||
|
|
||||||
cls.__str__ = __str__
|
cls.__str__ = __str__
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
|
||||||
|
def is_sequence(value):
|
||||||
|
return (
|
||||||
|
isinstance(value, Iterable) and not isinstance(value, six.string_types)
|
||||||
|
)
|
||||||
|
Reference in New Issue
Block a user