Fix model update flows

The controller worker flow to update the data model was simplistically
setting new attribute values to whatever made it through the API's
verification / syntax check. While this works well for simple
attributes, if the attribute happens to be a reference to another
object, or if updating that attribute were to affect some other
attribute or linked object, this would case the data model update to fail.

The root of the issue here is that the data model code does not
currently duplicate all the functionality that is handled by SQLAlchemy
for the database models. This was complicated by the fact that we have
at the present time, essentially no tests for pure data model updates
(ie. which don't hit the underlying repository / database).

In particular, this update corrects update problems for
pool.session_persistence attributes, as well as listener.default_pool_id
attribute updates.

Closes-Bug: 1544851
Change-Id: I8617dcb38013456d5ba5e10a92e25be0c07e3e8d
This commit is contained in:
Stephen Balukoff 2016-02-15 18:45:25 -08:00
parent 92046a481f
commit 24ef5c9d96
5 changed files with 138 additions and 4 deletions
octavia
common
controller/worker/tasks
db
tests
functional/db
unit/controller/worker/tasks

@ -15,6 +15,8 @@
import re
from sqlalchemy.orm import collections
class BaseDataModel(object):
@ -47,6 +49,66 @@ class BaseDataModel(object):
# Split the class name up by capitalized words
return ' '.join(re.findall('[A-Z][^A-Z]*', cls.__name__))
def _get_unique_key(self, obj=None):
"""Returns a unique key for passed object for data model building."""
obj = obj or self
# First handle all objects with their own ID, then handle subordinate
# objects.
if obj.__class__.__name__ in ['Member', 'Pool', 'LoadBalancer',
'Listener', 'Amphora']:
return obj.__class__.__name__ + obj.id
elif obj.__class__.__name__ in ['SessionPersistence', 'HealthMonitor']:
return obj.__class__.__name__ + obj.pool_id
elif obj.__class__.__name__ in ['ListenerStatistics', 'SNI']:
return obj.__class__.__name__ + obj.listener_id
elif obj.__class__.__name__ in ['VRRPGroup', 'Vip']:
return obj.__class__.__name__ + obj.load_balancer_id
elif obj.__class__.__name__ in ['AmphoraHealth']:
return obj.__class__.__name__ + obj.amphora_id
else:
raise NotImplementedError
def _find_in_graph(self, key, _visited_nodes=None):
"""Locates an object with the given unique key in the current
object graph and returns a reference to it.
"""
_visited_nodes = _visited_nodes or []
mykey = self._get_unique_key()
if mykey in _visited_nodes:
# Seen this node already, don't traverse further
return None
elif mykey == key:
return self
else:
_visited_nodes.append(mykey)
attr_names = [attr_name for attr_name in dir(self)
if not attr_name.startswith('_')]
for attr_name in attr_names:
attr = getattr(self, attr_name)
if isinstance(attr, BaseDataModel):
result = attr._find_in_graph(
key, _visited_nodes=_visited_nodes)
if result is not None:
return result
elif isinstance(attr, (collections.InstrumentedList, list)):
for item in attr:
if isinstance(item, BaseDataModel):
result = item._find_in_graph(
key, _visited_nodes=_visited_nodes)
if result is not None:
return result
# If we are here we didn't find it.
return None
def update(self, update_dict):
"""Generic update method which works for simple,
non-relational attributes.
"""
for key, value in update_dict.items():
setattr(self, key, value)
class SessionPersistence(BaseDataModel):
@ -57,6 +119,9 @@ class SessionPersistence(BaseDataModel):
self.cookie_name = cookie_name
self.pool = pool
def delete(self):
self.pool.session_persistence = None
class ListenerStatistics(BaseDataModel):
@ -70,6 +135,9 @@ class ListenerStatistics(BaseDataModel):
self.total_connections = total_connections
self.listener = listener
def delete(self):
self.listener.stats = None
class HealthMonitor(BaseDataModel):
@ -117,12 +185,31 @@ class Pool(BaseDataModel):
self.session_persistence = session_persistence
self.listeners = listeners or []
def update(self, update_dict):
for key, value in update_dict.items():
if key == 'session_persistence':
if self.session_persistence is not None:
self.session_persistence.update(value)
else:
value.update({'pool_id': self.id})
self.session_persistence = SessionPersistence(**value)
else:
setattr(self, key, value)
def delete(self):
# TODO(sbalukoff): Clean up L7Policies that reference this pool
for listener in self.listeners:
if listener.default_pool_id == self.id:
listener.default_pool = None
listener.default_pool_id = None
for pool in listener.pools:
if pool.id == self.id:
listener.pools.remove(pool)
break
for pool in self.load_balancer.pools:
if pool.id == self.id:
self.load_balancer.pools.remove(pool)
break
class Member(BaseDataModel):
@ -177,6 +264,28 @@ class Listener(BaseDataModel):
self.peer_port = peer_port
self.pools = pools or []
def update(self, update_dict):
for key, value in update_dict.items():
setattr(self, key, value)
if key == 'default_pool_id':
if value is not None:
pool = self._find_in_graph('Pool' + value)
if pool not in self.pools:
self.pools.append(pool)
if self not in pool.listeners:
pool.listeners.append(self)
else:
pool = None
setattr(self, 'default_pool', pool)
def delete(self):
for listener in self.load_balancer.listeners:
if listener.id == self.id:
self.load_balancer.listeners.remove(listener)
break
for pool in self.pools:
pool.listeners.remove(self)
class LoadBalancer(BaseDataModel):
@ -271,6 +380,12 @@ class Amphora(BaseDataModel):
self.cert_expiration = cert_expiration
self.cert_busy = cert_busy
def delete(self):
for amphora in self.load_balancer.amphorae:
if amphora.id == self.id:
self.load_balancer.amphorae.remove(amphora)
break
class AmphoraHealth(BaseDataModel):

@ -29,5 +29,4 @@ class UpdateAttributes(task.Task):
def execute(self, object, update_dict):
for key, value in update_dict.items():
setattr(object, key, value)
object.update(update_dict)

@ -188,7 +188,10 @@ class Repositories(object):
:returns: octavia.common.data_models.Pool
"""
with session.begin(subtransactions=True):
self.pool.update(session, pool_id, **pool_dict)
# If only the session persistence is being updated, this will be
# empty
if len(pool_dict.keys()) > 0:
self.pool.update(session, pool_id, **pool_dict)
if sp_dict:
if self.session_persistence.exists(session, pool_id):
self.session_persistence.update(session, pool_id,

@ -696,6 +696,23 @@ class DataModelConversionTest(base.OctaviaDBTestBase, ModelTestMixin):
load_balancer.pools[0].members[0].pool.load_balancer.id)
self.assertEqual(lb_dm.id, m_lb_id)
def test_update_data_model_listener_default_pool_id(self):
lb_dm = self.create_load_balancer(
self.session, id=uuidutils.generate_uuid()).to_data_model()
pool1_dm = self.create_pool(
self.session, id=uuidutils.generate_uuid(),
load_balancer_id=lb_dm.id).to_data_model()
pool2_dm = self.create_pool(
self.session, id=uuidutils.generate_uuid(),
load_balancer_id=lb_dm.id).to_data_model()
listener_dm = self.create_listener(
self.session, id=uuidutils.generate_uuid(),
load_balancer_id=lb_dm.id,
default_pool_id=pool1_dm.id).to_data_model()
self.assertEqual(pool1_dm.id, listener_dm.default_pool.id)
listener_dm.update({'default_pool_id': pool2_dm.id})
self.assertEqual(listener_dm.default_pool.id, pool2_dm.id)
def test_load_balancer_tree(self):
lb_db = self.session.query(models.LoadBalancer).filter_by(
id=self.lb.id).first()

@ -41,4 +41,4 @@ class TestObjectUpdateTasks(base.TestCase):
update_attr.execute(self.listener_mock,
{'name': 'TEST2'})
assert self.listener_mock.name == 'TEST2'
self.listener_mock.update.assert_called_once_with({'name': 'TEST2'})