diff --git a/taskflow/persistence/backends/impl_zookeeper.py b/taskflow/persistence/backends/impl_zookeeper.py index 024398d2..e60bad85 100644 --- a/taskflow/persistence/backends/impl_zookeeper.py +++ b/taskflow/persistence/backends/impl_zookeeper.py @@ -156,8 +156,10 @@ class ZkConnection(base.Connection): def update_atom_details(self, ad): """Update a atom detail transactionally.""" with self._exc_wrapper(): - with self._client.transaction() as txn: - return self._update_atom_details(ad, txn) + txn = self._client.transaction() + ad = self._update_atom_details(ad, txn) + k_utils.checked_commit(txn) + return ad def _update_atom_details(self, ad, txn, create_missing=False): # Determine whether the desired data exists or not. @@ -209,8 +211,10 @@ class ZkConnection(base.Connection): def update_flow_details(self, fd): """Update a flow detail transactionally.""" with self._exc_wrapper(): - with self._client.transaction() as txn: - return self._update_flow_details(fd, txn) + txn = self._client.transaction() + fd = self._update_flow_details(fd, txn) + k_utils.checked_commit(txn) + return fd def _update_flow_details(self, fd, txn, create_missing=False): # Determine whether the desired data exists or not @@ -306,19 +310,19 @@ class ZkConnection(base.Connection): return e_lb with self._exc_wrapper(): - with self._client.transaction() as txn: - # Determine whether the desired data exists or not. - lb_path = paths.join(self.book_path, lb.uuid) - try: - lb_data, _zstat = self._client.get(lb_path) - except k_exc.NoNodeError: - # Create a new logbook since it doesn't exist. - e_lb = _create_logbook(lb_path, txn) - else: - # Otherwise update the existing logbook instead. - e_lb = _update_logbook(lb_path, lb_data, txn) - # Finally return (updated) logbook. - return e_lb + txn = self._client.transaction() + # Determine whether the desired data exists or not. + lb_path = paths.join(self.book_path, lb.uuid) + try: + lb_data, _zstat = self._client.get(lb_path) + except k_exc.NoNodeError: + # Create a new logbook since it doesn't exist. + e_lb = _create_logbook(lb_path, txn) + else: + # Otherwise update the existing logbook instead. + e_lb = _update_logbook(lb_path, lb_data, txn) + k_utils.checked_commit(txn) + return e_lb def _get_logbook(self, lb_uuid): lb_path = paths.join(self.book_path, lb_uuid) @@ -380,35 +384,38 @@ class ZkConnection(base.Connection): txn.delete(lb_path) with self._exc_wrapper(): - with self._client.transaction() as txn: - _destroy_logbook(lb_uuid, txn) + txn = self._client.transaction() + _destroy_logbook(lb_uuid, txn) + k_utils.checked_commit(txn) def clear_all(self, delete_dirs=True): """Delete all data transactionally.""" with self._exc_wrapper(): - with self._client.transaction() as txn: + txn = self._client.transaction() - # Delete all data under logbook path. - for lb_uuid in self._client.get_children(self.book_path): - lb_path = paths.join(self.book_path, lb_uuid) - for fd_uuid in self._client.get_children(lb_path): - txn.delete(paths.join(lb_path, fd_uuid)) - txn.delete(lb_path) + # Delete all data under logbook path. + for lb_uuid in self._client.get_children(self.book_path): + lb_path = paths.join(self.book_path, lb_uuid) + for fd_uuid in self._client.get_children(lb_path): + txn.delete(paths.join(lb_path, fd_uuid)) + txn.delete(lb_path) - # Delete all data under flow detail path. - for fd_uuid in self._client.get_children(self.flow_path): - fd_path = paths.join(self.flow_path, fd_uuid) - for ad_uuid in self._client.get_children(fd_path): - txn.delete(paths.join(fd_path, ad_uuid)) - txn.delete(fd_path) + # Delete all data under flow detail path. + for fd_uuid in self._client.get_children(self.flow_path): + fd_path = paths.join(self.flow_path, fd_uuid) + for ad_uuid in self._client.get_children(fd_path): + txn.delete(paths.join(fd_path, ad_uuid)) + txn.delete(fd_path) - # Delete all data under atom detail path. - for ad_uuid in self._client.get_children(self.atom_path): - ad_path = paths.join(self.atom_path, ad_uuid) - txn.delete(ad_path) + # Delete all data under atom detail path. + for ad_uuid in self._client.get_children(self.atom_path): + ad_path = paths.join(self.atom_path, ad_uuid) + txn.delete(ad_path) - # Delete containing directories. - if delete_dirs: - txn.delete(self.book_path) - txn.delete(self.atom_path) - txn.delete(self.flow_path) + # Delete containing directories. + if delete_dirs: + txn.delete(self.book_path) + txn.delete(self.atom_path) + txn.delete(self.flow_path) + + k_utils.checked_commit(txn) diff --git a/taskflow/utils/kazoo_utils.py b/taskflow/utils/kazoo_utils.py index 84f6b262..ae62e880 100644 --- a/taskflow/utils/kazoo_utils.py +++ b/taskflow/utils/kazoo_utils.py @@ -15,9 +15,11 @@ # under the License. from kazoo import client +from kazoo import exceptions as k_exc import six from taskflow import exceptions as exc +from taskflow.utils import reflection def _parse_hosts(hosts): @@ -33,6 +35,92 @@ def _parse_hosts(hosts): return hosts +def prettify_failures(failures, limit=-1): + """Prettifies a checked commits failures (ignores sensitive data...). + + Example input and output: + + >>> from taskflow.utils import kazoo_utils + >>> conf = {"hosts": ['localhost:2181']} + >>> c = kazoo_utils.make_client(conf) + >>> c.start(timeout=1) + >>> txn = c.transaction() + >>> txn.create("/test") + >>> txn.check("/test", 2) + >>> txn.delete("/test") + >>> try: + ... kazoo_utils.checked_commit(txn) + ... except kazoo_utils.KazooTransactionException as e: + ... print(kazoo_utils.prettify_failures(e.failures, limit=1)) + ... + RolledBackError@Create(path='/test') and 2 more... + >>> c.stop() + >>> c.close() + """ + prettier = [] + for (op, r) in failures: + pretty_op = reflection.get_class_name(op, fully_qualified=False) + # Pick off a few attributes that are meaningful (but one that don't + # show actual data, which might not be desired to show...). + selected_attrs = [ + "path=%r" % op.path, + ] + try: + if op.version != -1: + selected_attrs.append("version=%s" % op.version) + except AttributeError: + pass + pretty_op += "(%s)" % (", ".join(selected_attrs)) + pretty_cause = reflection.get_class_name(r, fully_qualified=False) + prettier.append("%s@%s" % (pretty_cause, pretty_op)) + if limit <= 0 or len(prettier) <= limit: + return ", ".join(prettier) + else: + leftover = prettier[limit:] + prettier = prettier[0:limit] + return ", ".join(prettier) + " and %s more..." % len(leftover) + + +class KazooTransactionException(k_exc.KazooException): + """Exception raised when a checked commit fails.""" + + def __init__(self, message, failures): + super(KazooTransactionException, self).__init__(message) + self._failures = tuple(failures) + + @property + def failures(self): + return self._failures + + +def checked_commit(txn): + # Until https://github.com/python-zk/kazoo/pull/224 is fixed we have + # to workaround the transaction failing silently. + if not txn.operations: + return [] + results = txn.commit() + failures = [] + for op, result in six.moves.zip(txn.operations, results): + if isinstance(result, k_exc.KazooException): + failures.append((op, result)) + if len(results) < len(txn.operations): + raise KazooTransactionException( + "Transaction returned %s results, this is less than" + " the number of expected transaction operations %s" + % (len(results), len(txn.operations)), failures) + if len(results) > len(txn.operations): + raise KazooTransactionException( + "Transaction returned %s results, this is greater than" + " the number of expected transaction operations %s" + % (len(results), len(txn.operations)), failures) + if failures: + raise KazooTransactionException( + "Transaction with %s operations failed: %s" + % (len(txn.operations), + prettify_failures(failures, limit=1)), failures) + return results + + def finalize_client(client): """Stops and closes a client, even if it wasn't started.""" client.stop() diff --git a/taskflow/utils/reflection.py b/taskflow/utils/reflection.py index c7f1a06a..b386dfa2 100644 --- a/taskflow/utils/reflection.py +++ b/taskflow/utils/reflection.py @@ -77,7 +77,7 @@ def get_member_names(obj, exclude_hidden=True): return [name for (name, _obj) in _get_members(obj, exclude_hidden)] -def get_class_name(obj): +def get_class_name(obj, fully_qualified=True): """Get class name for object. If object is a type, fully qualified name of the type is returned. @@ -88,7 +88,10 @@ def get_class_name(obj): obj = type(obj) if obj.__module__ in ('builtins', '__builtin__', 'exceptions'): return obj.__name__ - return '.'.join((obj.__module__, obj.__name__)) + if fully_qualified: + return '.'.join((obj.__module__, obj.__name__)) + else: + return obj.__name__ def get_all_class_names(obj, up_to=object):