diff --git a/taskflow/persistence/backends/impl_memory.py b/taskflow/persistence/backends/impl_memory.py index 6b937f4a..b51c22f1 100644 --- a/taskflow/persistence/backends/impl_memory.py +++ b/taskflow/persistence/backends/impl_memory.py @@ -19,7 +19,6 @@ """Implementation of in-memory backend.""" -import copy import logging import threading @@ -27,31 +26,42 @@ from taskflow import decorators from taskflow import exceptions as exc from taskflow.openstack.common import timeutils from taskflow.persistence.backends import base +from taskflow.persistence import logbook from taskflow.utils import persistence_utils as p_utils LOG = logging.getLogger(__name__) -# TODO(harlowja): we likely need to figure out a better place to put these -# rather than globals. -_LOG_BOOKS = {} -_FLOW_DETAILS = {} -_TASK_DETAILS = {} - -# For now this will be a pretty big lock, since it is not expected that saves -# will be that frequent this seems ok for the time being. I imagine that this -# can be done better but it will require much more careful usage of a dict as -# a key/value map. Aka I wish python had a concurrent dict that was safe and -# known good to use. -_SAVE_LOCK = threading.RLock() -_READ_LOCK = threading.RLock() -_READ_SAVE_ORDER = (_READ_LOCK, _SAVE_LOCK) - - -def _copy(obj): - return copy.deepcopy(obj) - class MemoryBackend(base.Backend): + def __init__(self, conf): + super(MemoryBackend, self).__init__(conf) + self._log_books = {} + self._flow_details = {} + self._task_details = {} + self._save_lock = threading.RLock() + self._read_lock = threading.RLock() + self._read_save_order = (self._read_lock, self._save_lock) + + @property + def log_books(self): + return self._log_books + + @property + def flow_details(self): + return self._flow_details + + @property + def task_details(self): + return self._task_details + + @property + def read_locks(self): + return (self._read_lock,) + + @property + def save_locks(self): + return self._read_save_order + def get_connection(self): return Connection(self) @@ -61,8 +71,8 @@ class MemoryBackend(base.Backend): class Connection(base.Connection): def __init__(self, backend): - self._read_lock = _READ_LOCK - self._save_locks = _READ_SAVE_ORDER + self._read_locks = backend.read_locks + self._save_locks = backend.save_locks self._backend = backend def upgrade(self): @@ -78,7 +88,7 @@ class Connection(base.Connection): @decorators.locked(lock="_save_locks") def clear_all(self): count = 0 - for uuid in list(_LOG_BOOKS.iterkeys()): + for uuid in list(self.backend.log_books.keys()): self.destroy_logbook(uuid) count += 1 return count @@ -87,80 +97,87 @@ class Connection(base.Connection): def destroy_logbook(self, book_uuid): try: # Do the same cascading delete that the sql layer does. - lb = _LOG_BOOKS.pop(book_uuid) + lb = self.backend.log_books.pop(book_uuid) for fd in lb: - _FLOW_DETAILS.pop(fd.uuid, None) + self.backend.flow_details.pop(fd.uuid, None) for td in fd: - _TASK_DETAILS.pop(td.uuid, None) + self.backend.task_details.pop(td.uuid, None) except KeyError: raise exc.NotFound("No logbook found with id: %s" % book_uuid) @decorators.locked(lock="_save_locks") def update_task_details(self, task_detail): try: - return p_utils.task_details_merge(_TASK_DETAILS[task_detail.uuid], - task_detail) + e_td = self.backend.task_details[task_detail.uuid] except KeyError: raise exc.NotFound("No task details found with id: %s" % task_detail.uuid) + return p_utils.task_details_merge(e_td, task_detail, deep_copy=True) + + def _save_flowdetail_tasks(self, e_fd, flow_detail): + for task_detail in flow_detail: + e_td = e_fd.find(task_detail.uuid) + if e_td is None: + e_td = logbook.TaskDetail(name=task_detail.name, + uuid=task_detail.uuid) + e_fd.add(e_td) + if task_detail.uuid not in self.backend.task_details: + self.backend.task_details[task_detail.uuid] = e_td + p_utils.task_details_merge(e_td, task_detail, deep_copy=True) @decorators.locked(lock="_save_locks") def update_flow_details(self, flow_detail): try: - e_fd = p_utils.flow_details_merge(_FLOW_DETAILS[flow_detail.uuid], - flow_detail) - for task_detail in flow_detail: - if e_fd.find(task_detail.uuid) is None: - _TASK_DETAILS[task_detail.uuid] = _copy(task_detail) - e_fd.add(task_detail) - if task_detail.uuid not in _TASK_DETAILS: - _TASK_DETAILS[task_detail.uuid] = _copy(task_detail) - task_detail.update(self.update_task_details(task_detail)) - return e_fd + e_fd = self.backend.flow_details[flow_detail.uuid] except KeyError: raise exc.NotFound("No flow details found with id: %s" % flow_detail.uuid) + p_utils.flow_details_merge(e_fd, flow_detail, deep_copy=True) + self._save_flowdetail_tasks(e_fd, flow_detail) + return e_fd @decorators.locked(lock="_save_locks") def save_logbook(self, book): # Get a existing logbook model (or create it if it isn't there). try: - e_lb = p_utils.logbook_merge(_LOG_BOOKS[book.uuid], book) - # Add anything in to the new logbook that isn't already - # in the existing logbook. - for flow_detail in book: - if e_lb.find(flow_detail.uuid) is None: - _FLOW_DETAILS[flow_detail.uuid] = _copy(flow_detail) - e_lb.add(flow_detail) - if flow_detail.uuid not in _FLOW_DETAILS: - _FLOW_DETAILS[flow_detail.uuid] = _copy(flow_detail) - flow_detail.update(self.update_flow_details(flow_detail)) + e_lb = self.backend.log_books[book.uuid] + except KeyError: + e_lb = logbook.LogBook(book.name, book.uuid, + updated_at=book.updated_at, + created_at=timeutils.utcnow()) + self.backend.log_books[e_lb.uuid] = e_lb + else: # TODO(harlowja): figure out a better way to set this property # without actually setting a 'private' property. e_lb._updated_at = timeutils.utcnow() - except KeyError: - # Ok the one given is now the one we will save - e_lb = _copy(book) - # TODO(harlowja): figure out a better way to set this property - # without actually setting a 'private' property. - e_lb._created_at = timeutils.utcnow() - # Record all the pieces as being saved. - _LOG_BOOKS[e_lb.uuid] = e_lb - for flow_detail in e_lb: - _FLOW_DETAILS[flow_detail.uuid] = _copy(flow_detail) - flow_detail.update(self.update_flow_details(flow_detail)) + + p_utils.logbook_merge(e_lb, book, deep_copy=True) + # Add anything in to the new logbook that isn't already + # in the existing logbook. + for flow_detail in book: + try: + e_fd = self.backend.flow_details[flow_detail.uuid] + except KeyError: + e_fd = logbook.FlowDetail(name=flow_detail.name, + uuid=flow_detail.uuid) + e_lb.add(flow_detail) + self.backend.flow_details[flow_detail.uuid] = e_fd + p_utils.flow_details_merge(e_fd, flow_detail, deep_copy=True) + self._save_flowdetail_tasks(e_fd, flow_detail) return e_lb - @decorators.locked(lock='_read_lock') + @decorators.locked(lock='_read_locks') def get_logbook(self, book_uuid): try: - return _LOG_BOOKS[book_uuid] + return self.backend.log_books[book_uuid] except KeyError: raise exc.NotFound("No logbook found with id: %s" % book_uuid) + @decorators.locked(lock='_read_locks') + def _get_logbooks(self): + return list(self.backend.log_books.values()) + def get_logbooks(self): # NOTE(harlowja): don't hold the lock while iterating - with self._read_lock: - books = list(_LOG_BOOKS.values()) - for lb in books: + for lb in self._get_logbooks(): yield lb diff --git a/taskflow/tests/unit/persistence/test_memory_persistence.py b/taskflow/tests/unit/persistence/test_memory_persistence.py index da088c13..cdc9bbcf 100644 --- a/taskflow/tests/unit/persistence/test_memory_persistence.py +++ b/taskflow/tests/unit/persistence/test_memory_persistence.py @@ -22,10 +22,14 @@ from taskflow.tests.unit.persistence import base class MemoryPersistenceTest(test.TestCase, base.PersistenceTestMixin): + def setUp(self): + self._backend = impl_memory.MemoryBackend({}) + def _get_connection(self): - return impl_memory.MemoryBackend({}).get_connection() + return self._backend.get_connection() def tearDown(self): conn = self._get_connection() conn.clear_all() + self._backend = None super(MemoryPersistenceTest, self).tearDown() diff --git a/taskflow/utils/persistence_utils.py b/taskflow/utils/persistence_utils.py index 8562e96e..fa6f5f99 100644 --- a/taskflow/utils/persistence_utils.py +++ b/taskflow/utils/persistence_utils.py @@ -17,6 +17,7 @@ # under the License. import contextlib +import copy import logging from taskflow.openstack.common import uuidutils @@ -84,47 +85,71 @@ def create_flow_detail(flow, book=None, backend=None): return flow_detail -def task_details_merge(td_e, td_new): - """Merges an existing task details with a new task details object, the new - task details fields, if they differ will replace the existing objects - fields (except name, version, uuid which can not be replaced). +def _copy_functon(deep_copy): + if deep_copy: + return copy.deepcopy + else: + return lambda x: x + + +def task_details_merge(td_e, td_new, deep_copy=False): + """Merges an existing task details with a new task details object. + + The new task details fields, if they differ will replace the existing + objects fields (except name, version, uuid which can not be replaced). + + If 'deep_copy' is True, fields are copied deeply (by value) if possible. """ if td_e is td_new: return td_e + + copy_fn = _copy_functon(deep_copy) if td_e.state != td_new.state: + # NOTE(imelnikov): states are just strings, no need to copy td_e.state = td_new.state if td_e.results != td_new.results: - td_e.results = td_new.results + td_e.results = copy_fn(td_new.results) if td_e.exception != td_new.exception: - td_e.exception = td_new.exception + td_e.exception = copy_fn(td_new.exception) if td_e.stacktrace != td_new.stacktrace: - td_e.stacktrace = td_new.stacktrace + td_e.stacktrace = copy_fn(td_new.stacktrace) if td_e.meta != td_new.meta: - td_e.meta = td_new.meta + td_e.meta = copy_fn(td_new.meta) return td_e -def flow_details_merge(fd_e, fd_new): - """Merges an existing flow details with a new flow details object, the new - flow details fields, if they differ will replace the existing objects - fields (except name and uuid which can not be replaced). +def flow_details_merge(fd_e, fd_new, deep_copy=False): + """Merges an existing flow details with a new flow details object. + + The new flow details fields, if they differ will replace the existing + objects fields (except name and uuid which can not be replaced). + + If 'deep_copy' is True, fields are copied deeply (by value) if possible. """ if fd_e is fd_new: return fd_e + + copy_fn = _copy_functon(deep_copy) if fd_e.meta != fd_new.meta: - fd_e.meta = fd_new.meta + fd_e.meta = copy_fn(fd_new.meta) if fd_e.state != fd_new.state: + # NOTE(imelnikov): states are just strings, no need to copy fd_e.state = fd_new.state return fd_e -def logbook_merge(lb_e, lb_new): - """Merges an existing logbook with a new logbook object, the new logbook - fields, if they differ will replace the existing objects fields (except - name and uuid which can not be replaced). +def logbook_merge(lb_e, lb_new, deep_copy=False): + """Merges an existing logbook with a new logbook object. + + The new logbook fields, if they differ will replace the existing + objects fields (except name and uuid which can not be replaced). + + If 'deep_copy' is True, fields are copied deeply (by value) if possible. """ if lb_e is lb_new: return lb_e + + copy_fn = _copy_functon(deep_copy) if lb_e.meta != lb_new.meta: - lb_e.meta = lb_new.meta + lb_e.meta = copy_fn(lb_new.meta) return lb_e