Merge "Introduce oslo.versionedobjects to smaug"

This commit is contained in:
Jenkins 2016-02-01 11:27:08 +00:00 committed by Gerrit Code Review
commit 10a5ac9c78
16 changed files with 743 additions and 2 deletions

View File

@ -17,6 +17,7 @@ oslo.middleware>=3.0.0 # Apache-2.0
oslo.policy>=0.5.0 # Apache-2.0
oslo.serialization>=1.10.0 # Apache-2.0
oslo.service>=1.0.0 # Apache-2.0
oslo.versionedobjects>=1.4.0 # Apache-2.0
Paste
PasteDeploy>=1.5.0
requests>=2.8.1

View File

@ -25,7 +25,7 @@ from oslo_log import log as logging
from smaug.common import config # noqa
from smaug import i18n
i18n.enable_lazy()
from smaug import objects
from smaug import rpc
from smaug import service
from smaug import version
@ -35,7 +35,7 @@ CONF = cfg.CONF
def main():
objects.register_all()
CONF(sys.argv[1:], project='smaug',
version=version.version_string())
logging.setup(CONF, "smaug")

View File

@ -35,6 +35,7 @@ from smaug import db
from smaug.db import migration as db_migration
from smaug.db.sqlalchemy import api as db_api
from smaug.i18n import _
from smaug import objects
from smaug import utils
from smaug import version
@ -205,6 +206,7 @@ def fetch_func_args(func):
def main():
"""Parse options and call the appropriate class/method."""
objects.register_all()
CONF.register_cli_opt(category_opt)
script_name = sys.argv[0]
if len(sys.argv) < 2:

View File

@ -23,6 +23,7 @@ from oslo_log import log as logging
from smaug import i18n
i18n.enable_lazy()
from smaug import objects
# Need to register global_opts
from smaug.common import config # noqa
@ -34,6 +35,7 @@ CONF = cfg.CONF
def main():
objects.register_all()
CONF(sys.argv[1:], project='smaug',
version=version.version_string())
logging.setup(CONF, "smaug")

View File

@ -23,6 +23,7 @@ from oslo_log import log as logging
from smaug import i18n
i18n.enable_lazy()
from smaug import objects
# Need to register global_opts
from smaug.common import config # noqa
@ -34,6 +35,7 @@ CONF = cfg.CONF
def main():
objects.register_all()
CONF(sys.argv[1:], project='smaug',
version=version.version_string())
logging.setup(CONF, "smaug")

View File

@ -118,3 +118,7 @@ def service_update(context, service_id, values):
"""
return IMPL.service_update(context, service_id, values)
def get_by_id(context, model, id, *args, **kwargs):
return IMPL.get_by_id(context, model, id, *args, **kwargs)

View File

@ -14,6 +14,7 @@
import functools
import re
import sys
import threading
import time
@ -39,6 +40,7 @@ options.set_defaults(CONF, connection='sqlite:///$state_path/smaug.sqlite')
_LOCK = threading.Lock()
_FACADE = None
_GET_METHODS = {}
def _create_facade_lazily():
@ -306,3 +308,21 @@ def service_update(context, service_id, values):
service_ref['updated_at'] = literal_column('updated_at')
service_ref.update(values)
return service_ref
def _get_get_method(model):
# General conversion
# Convert camel cased model name to snake format
s = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', model.__name__)
# Get method must be snake formatted model name concatenated with _get
method_name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', s).lower() + '_get'
return globals().get(method_name)
@require_context
def get_by_id(context, model, id, *args, **kwargs):
# Add get method to cache dictionary if it's not already there
if not _GET_METHODS.get(model):
_GET_METHODS[model] = _get_get_method(model)
return _GET_METHODS[model](context, id, *args, **kwargs)

View File

@ -22,6 +22,7 @@ import sys
from oslo_config import cfg
from oslo_log import log as logging
from oslo_versionedobjects import exception as obj_exc
import six
import webob.exc
from webob.util import status_generic_reasons
@ -177,3 +178,7 @@ class ServiceNotFound(NotFound):
class HostBinaryNotFound(NotFound):
message = _("Could not find binary %(binary)s on host %(host)s.")
OrphanedObjectError = obj_exc.OrphanedObjectError
ObjectActionError = obj_exc.ObjectActionError

18
smaug/objects/__init__.py Normal file
View File

@ -0,0 +1,18 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
def register_all():
# You must make sure your object gets imported in this
# function in order for it to be registered by services that may
# need to receive it via RPC.
__import__('smaug.objects.service')

203
smaug/objects/base.py Normal file
View File

@ -0,0 +1,203 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Smaug common internal object model"""
import contextlib
import datetime
from oslo_log import log as logging
from oslo_versionedobjects import base
from oslo_versionedobjects import fields
from smaug import db
from smaug.db.sqlalchemy import models
from smaug import exception
from smaug.i18n import _
from smaug import objects
LOG = logging.getLogger('object')
remotable = base.remotable
remotable_classmethod = base.remotable_classmethod
obj_make_list = base.obj_make_list
class SmaugObjectRegistry(base.VersionedObjectRegistry):
def registration_hook(self, cls, index):
setattr(objects, cls.obj_name(), cls)
# For Versioned Object Classes that have a model store the model in
# a Class attribute named model
try:
model_name = cls.obj_name()
cls.model = getattr(models, model_name)
except (ImportError, AttributeError):
pass
class SmaugObject(base.VersionedObject):
OBJ_SERIAL_NAMESPACE = 'smaug_object'
OBJ_PROJECT_NAMESPACE = 'smaug'
def smaug_obj_get_changes(self):
"""Returns a dict of changed fields with tz unaware datetimes.
Any timezone aware datetime field will be converted to UTC timezone
and returned as timezone unaware datetime.
This will allow us to pass these fields directly to a db update
method as they can't have timezone information.
"""
# Get dirtied/changed fields
changes = self.obj_get_changes()
# Look for datetime objects that contain timezone information
for k, v in changes.items():
if isinstance(v, datetime.datetime) and v.tzinfo:
# Remove timezone information and adjust the time according to
# the timezone information's offset.
changes[k] = v.replace(tzinfo=None) - v.utcoffset()
# Return modified dict
return changes
@base.remotable_classmethod
def get_by_id(cls, context, id, *args, **kwargs):
# To get by id we need to have a model and for the model to
# have an id field
if 'id' not in cls.fields:
msg = (_('VersionedObject %s cannot retrieve object by id.') %
(cls.obj_name()))
raise NotImplementedError(msg)
model = getattr(models, cls.obj_name())
orm_obj = db.get_by_id(context, model, id, *args, **kwargs)
kargs = {}
if hasattr(cls, 'DEFAULT_EXPECTED_ATTR'):
kargs = {'expected_attrs': getattr(cls, 'DEFAULT_EXPECTED_ATTR')}
return cls._from_db_object(context, cls(context), orm_obj, **kargs)
def refresh(self):
# To refresh we need to have a model and for the model to have an id
# field
if 'id' not in self.fields:
msg = (_('VersionedObject %s cannot retrieve object by id.') %
(self.obj_name()))
raise NotImplementedError(msg)
current = self.get_by_id(self._context, self.id)
for field in self.fields:
# Only update attributes that are already set. We do not want to
# unexpectedly trigger a lazy-load.
if self.obj_attr_is_set(field):
if self[field] != current[field]:
self[field] = current[field]
self.obj_reset_changes()
def __contains__(self, name):
# We're using obj_extra_fields to provide aliases for some fields while
# in transition period. This override is to make these aliases pass
# "'foo' in obj" tests.
return name in self.obj_extra_fields or super(SmaugObject,
self).__contains__(name)
class SmaugObjectDictCompat(base.VersionedObjectDictCompat):
"""Mix-in to provide dictionary key access compat.
If an object needs to support attribute access using
dictionary items instead of object attributes, inherit
from this class. This should only be used as a temporary
measure until all callers are converted to use modern
attribute access.
NOTE(berrange) This class will eventually be deleted.
"""
def get(self, key, value=base._NotSpecifiedSentinel):
"""For backwards-compatibility with dict-based objects.
NOTE(danms): May be removed in the future.
"""
if key not in self.obj_fields:
# NOTE(jdg): There are a number of places where we rely on the
# old dictionary version and do a get(xxx, None).
# The following preserves that compatibility but in
# the future we'll remove this shim altogether so don't
# rely on it.
LOG.debug('Smaug object %(object_name)s has no '
'attribute named: %(attribute_name)s',
{'object_name': self.__class__.__name__,
'attribute_name': key})
return None
if (value != base._NotSpecifiedSentinel and
not self.obj_attr_is_set(key)):
return value
else:
try:
return getattr(self, key)
except (exception.ObjectActionError, NotImplementedError):
# Exception when haven't set a value for non-lazy
# loadable attribute, but to mimic typical dict 'get'
# behavior we should still return None
return None
class SmaugPersistentObject(object):
"""Mixin class for Persistent objects.
This adds the fields that we use in common for all persistent objects.
"""
fields = {
'created_at': fields.DateTimeField(nullable=True),
'updated_at': fields.DateTimeField(nullable=True),
'deleted_at': fields.DateTimeField(nullable=True),
'deleted': fields.BooleanField(default=False),
}
@contextlib.contextmanager
def obj_as_admin(self):
"""Context manager to make an object call as an admin.
This temporarily modifies the context embedded in an object to
be elevated() and restores it after the call completes. Example
usage:
with obj.obj_as_admin():
obj.save()
"""
if self._context is None:
raise exception.OrphanedObjectError(method='obj_as_admin',
objtype=self.obj_name())
original_context = self._context
self._context = self._context.elevated()
try:
yield
finally:
self._context = original_context
class SmaugComparableObject(base.ComparableVersionedObject):
def __eq__(self, obj):
if hasattr(obj, 'obj_to_primitive'):
return self.obj_to_primitive() == obj.obj_to_primitive()
return False
class ObjectListBase(base.ObjectListBase):
pass
class SmaugObjectSerializer(base.VersionedObjectSerializer):
OBJ_BASE_CLASS = SmaugObject

115
smaug/objects/service.py Normal file
View File

@ -0,0 +1,115 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo_config import cfg
from oslo_log import log as logging
from oslo_versionedobjects import fields
from smaug import db
from smaug import exception
from smaug.i18n import _
from smaug import objects
from smaug.objects import base
CONF = cfg.CONF
LOG = logging.getLogger(__name__)
@base.SmaugObjectRegistry.register
class Service(base.SmaugPersistentObject, base.SmaugObject,
base.SmaugObjectDictCompat,
base.SmaugComparableObject):
# Version 1.0: Initial version
VERSION = '1.0'
fields = {
'id': fields.IntegerField(),
'host': fields.StringField(nullable=True),
'binary': fields.StringField(nullable=True),
'topic': fields.StringField(nullable=True),
'report_count': fields.IntegerField(default=0),
'disabled': fields.BooleanField(default=False),
'disabled_reason': fields.StringField(nullable=True),
'modified_at': fields.DateTimeField(nullable=True),
'rpc_current_version': fields.StringField(nullable=True),
'rpc_available_version': fields.StringField(nullable=True),
}
@staticmethod
def _from_db_object(context, service, db_service):
for name, field in service.fields.items():
value = db_service.get(name)
if isinstance(field, fields.IntegerField):
value = value or 0
elif isinstance(field, fields.DateTimeField):
value = value or None
service[name] = value
service._context = context
service.obj_reset_changes()
return service
@base.remotable_classmethod
def get_by_host_and_topic(cls, context, host, topic):
db_service = db.service_get_by_host_and_topic(context, host, topic)
return cls._from_db_object(context, cls(context), db_service)
@base.remotable_classmethod
def get_by_args(cls, context, host, binary_key):
db_service = db.service_get_by_args(context, host, binary_key)
return cls._from_db_object(context, cls(context), db_service)
@base.remotable
def create(self):
if self.obj_attr_is_set('id'):
raise exception.ObjectActionError(action='create',
reason=_('already created'))
updates = self.smaug_obj_get_changes()
db_service = db.service_create(self._context, updates)
self._from_db_object(self._context, self, db_service)
@base.remotable
def save(self):
updates = self.smaug_obj_get_changes()
if updates:
db.service_update(self._context, self.id, updates)
self.obj_reset_changes()
@base.remotable
def destroy(self):
with self.obj_as_admin():
db.service_destroy(self._context, self.id)
@base.SmaugObjectRegistry.register
class ServiceList(base.ObjectListBase, base.SmaugObject):
VERSION = '1.0'
fields = {
'objects': fields.ListOfObjectsField('Service'),
}
child_versions = {
'1.0': '1.0'
}
@base.remotable_classmethod
def get_all(cls, context, filters=None):
services = db.service_get_all(context, filters)
return base.obj_make_list(context, cls(context), objects.Service,
services)
@base.remotable_classmethod
def get_all_by_topic(cls, context, topic, disabled=None):
services = db.service_get_all_by_topic(context, topic,
disabled=disabled)
return base.obj_make_list(context, cls(context), objects.Service,
services)

View File

@ -0,0 +1,18 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import eventlet
from smaug import objects
eventlet.monkey_patch()
objects.register_all()

View File

@ -0,0 +1,54 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo_utils import timeutils
from oslo_versionedobjects import fields
from smaug import objects
def fake_db_service(**updates):
NOW = timeutils.utcnow().replace(microsecond=0)
db_service = {
'created_at': NOW,
'updated_at': None,
'deleted_at': None,
'deleted': False,
'id': 123,
'host': 'fake-host',
'binary': 'fake-service',
'topic': 'fake-service-topic',
'report_count': 1,
'disabled': False,
'disabled_reason': None,
'modified_at': NOW,
}
for name, field in objects.Service.fields.items():
if name in db_service:
continue
if field.nullable:
db_service[name] = None
elif field.default != fields.UnspecifiedDefault:
db_service[name] = field.default
else:
raise Exception('fake_db_service needs help with %s.' % name)
if updates:
db_service.update(updates)
return db_service
def fake_service_obj(context, **updates):
return objects.Service._from_db_object(context, objects.Service(),
fake_db_service(**updates))

View File

@ -0,0 +1,43 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo_utils import timeutils
from smaug import context
from smaug.objects import base as obj_base
from smaug.tests import base
class BaseObjectsTestCase(base.TestCase):
def setUp(self):
super(BaseObjectsTestCase, self).setUp()
self.user_id = 'fake-user'
self.project_id = 'fake-project'
self.context = context.RequestContext(self.user_id, self.project_id,
is_admin=False)
# We only test local right now.
self.assertIsNone(obj_base.SmaugObject.indirection_api)
@staticmethod
def _compare(test, db, obj):
for field, value in db.items():
if not hasattr(obj, field):
continue
if field in ('modified_at', 'created_at',
'updated_at', 'deleted_at') and db[field]:
test.assertEqual(db[field],
timeutils.normalize_time(obj[field]))
elif isinstance(obj[field], obj_base.ObjectListBase):
test.assertEqual(db[field], obj[field].objects)
else:
test.assertEqual(db[field], obj[field])

View File

@ -0,0 +1,154 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime
import mock
import uuid
from iso8601 import iso8601
from oslo_versionedobjects import fields
from smaug import objects
from smaug.tests.unit import objects as test_objects
@objects.base.SmaugObjectRegistry.register_if(False)
class TestObject(objects.base.SmaugObject):
fields = {
'scheduled_at': objects.base.fields.DateTimeField(nullable=True),
'uuid': objects.base.fields.UUIDField(),
'text': objects.base.fields.StringField(nullable=True),
}
class TestSmaugObject(test_objects.BaseObjectsTestCase):
"""Tests methods from SmaugObject."""
def setUp(self):
super(TestSmaugObject, self).setUp()
self.obj = TestObject(
scheduled_at=None,
uuid=uuid.uuid4(),
text='text')
self.obj.obj_reset_changes()
def test_smaug_obj_get_changes_no_changes(self):
self.assertDictEqual({}, self.obj.smaug_obj_get_changes())
def test_smaug_obj_get_changes_other_changes(self):
self.obj.text = 'text2'
self.assertDictEqual({'text': 'text2'},
self.obj.smaug_obj_get_changes())
def test_smaug_obj_get_changes_datetime_no_tz(self):
now = datetime.datetime.utcnow()
self.obj.scheduled_at = now
self.assertDictEqual({'scheduled_at': now},
self.obj.smaug_obj_get_changes())
def test_smaug_obj_get_changes_datetime_tz_utc(self):
now_tz = iso8601.parse_date('2015-06-26T22:00:01Z')
now = now_tz.replace(tzinfo=None)
self.obj.scheduled_at = now_tz
self.assertDictEqual({'scheduled_at': now},
self.obj.smaug_obj_get_changes())
def test_smaug_obj_get_changes_datetime_tz_non_utc_positive(self):
now_tz = iso8601.parse_date('2015-06-26T22:00:01+01')
now = now_tz.replace(tzinfo=None) - datetime.timedelta(hours=1)
self.obj.scheduled_at = now_tz
self.assertDictEqual({'scheduled_at': now},
self.obj.smaug_obj_get_changes())
def test_smaug_obj_get_changes_datetime_tz_non_utc_negative(self):
now_tz = iso8601.parse_date('2015-06-26T10:00:01-05')
now = now_tz.replace(tzinfo=None) + datetime.timedelta(hours=5)
self.obj.scheduled_at = now_tz
self.assertDictEqual({'scheduled_at': now},
self.obj.smaug_obj_get_changes())
def test_refresh(self):
@objects.base.SmaugObjectRegistry.register_if(False)
class MyTestObject(objects.base.SmaugObject,
objects.base.SmaugObjectDictCompat,
objects.base.SmaugComparableObject):
fields = {'id': fields.UUIDField(),
'name': fields.StringField()}
test_obj = MyTestObject(id='1', name='foo')
refresh_obj = MyTestObject(id='1', name='bar')
with mock.patch(
'smaug.objects.base.SmaugObject.get_by_id') as get_by_id:
get_by_id.return_value = refresh_obj
test_obj.refresh()
self._compare(self, refresh_obj, test_obj)
def test_refresh_no_id_field(self):
@objects.base.SmaugObjectRegistry.register_if(False)
class MyTestObjectNoId(objects.base.SmaugObject,
objects.base.SmaugObjectDictCompat,
objects.base.SmaugComparableObject):
fields = {'uuid': fields.UUIDField()}
test_obj = MyTestObjectNoId(uuid='1', name='foo')
self.assertRaises(NotImplementedError, test_obj.refresh)
class TestSmaugComparableObject(test_objects.BaseObjectsTestCase):
def test_comparable_objects(self):
@objects.base.SmaugObjectRegistry.register
class MyComparableObj(objects.base.SmaugObject,
objects.base.SmaugObjectDictCompat,
objects.base.SmaugComparableObject):
fields = {'foo': fields.Field(fields.Integer())}
class NonVersionedObject(object):
pass
obj1 = MyComparableObj(foo=1)
obj2 = MyComparableObj(foo=1)
obj3 = MyComparableObj(foo=2)
obj4 = NonVersionedObject()
self.assertTrue(obj1 == obj2)
self.assertFalse(obj1 == obj3)
self.assertFalse(obj1 == obj4)
self.assertNotEqual(obj1, None)
class TestSmaugDictObject(test_objects.BaseObjectsTestCase):
@objects.base.SmaugObjectRegistry.register_if(False)
class TestDictObject(objects.base.SmaugObjectDictCompat,
objects.base.SmaugObject):
obj_extra_fields = ['foo']
fields = {
'abc': fields.StringField(nullable=True),
'def': fields.IntegerField(nullable=True),
}
@property
def foo(self):
return 42
def test_dict_objects(self):
obj = self.TestDictObject()
self.assertIsNone(obj.get('non_existing'))
self.assertEqual('val', obj.get('abc', 'val'))
self.assertIsNone(obj.get('abc'))
obj.abc = 'val2'
self.assertEqual('val2', obj.get('abc', 'val'))
self.assertEqual(42, obj.get('foo'))
self.assertTrue('foo' in obj)
self.assertTrue('abc' in obj)
self.assertFalse('def' in obj)

View File

@ -0,0 +1,100 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from smaug import objects
from smaug.tests.unit import fake_service
from smaug.tests.unit import objects as test_objects
class TestService(test_objects.BaseObjectsTestCase):
@mock.patch('smaug.db.sqlalchemy.api.service_get')
def test_get_by_id(self, service_get):
db_service = fake_service.fake_db_service()
service_get.return_value = db_service
service = objects.Service.get_by_id(self.context, 1)
self._compare(self, db_service, service)
service_get.assert_called_once_with(self.context, 1)
@mock.patch('smaug.db.service_get_by_host_and_topic')
def test_get_by_host_and_topic(self, service_get_by_host_and_topic):
db_service = fake_service.fake_db_service()
service_get_by_host_and_topic.return_value = db_service
service = objects.Service.get_by_host_and_topic(
self.context, 'fake-host', 'fake-topic')
self._compare(self, db_service, service)
service_get_by_host_and_topic.assert_called_once_with(
self.context, 'fake-host', 'fake-topic')
@mock.patch('smaug.db.service_get_by_args')
def test_get_by_args(self, service_get_by_args):
db_service = fake_service.fake_db_service()
service_get_by_args.return_value = db_service
service = objects.Service.get_by_args(
self.context, 'fake-host', 'fake-key')
self._compare(self, db_service, service)
service_get_by_args.assert_called_once_with(
self.context, 'fake-host', 'fake-key')
@mock.patch('smaug.db.service_create')
def test_create(self, service_create):
db_service = fake_service.fake_db_service()
service_create.return_value = db_service
service = objects.Service(context=self.context)
service.create()
self.assertEqual(db_service['id'], service.id)
service_create.assert_called_once_with(self.context, {})
@mock.patch('smaug.db.service_update')
def test_save(self, service_update):
db_service = fake_service.fake_db_service()
service = objects.Service._from_db_object(
self.context, objects.Service(), db_service)
service.topic = 'foobar'
service.save()
service_update.assert_called_once_with(self.context, service.id,
{'topic': 'foobar'})
@mock.patch('smaug.db.service_destroy')
def test_destroy(self, service_destroy):
db_service = fake_service.fake_db_service()
service = objects.Service._from_db_object(
self.context, objects.Service(), db_service)
with mock.patch.object(service._context, 'elevated') as elevated_ctx:
service.destroy()
service_destroy.assert_called_once_with(elevated_ctx(), 123)
class TestServiceList(test_objects.BaseObjectsTestCase):
@mock.patch('smaug.db.service_get_all')
def test_get_all(self, service_get_all):
db_service = fake_service.fake_db_service()
service_get_all.return_value = [db_service]
services = objects.ServiceList.get_all(self.context, 'foo')
service_get_all.assert_called_once_with(self.context, 'foo')
self.assertEqual(1, len(services))
TestService._compare(self, db_service, services[0])
@mock.patch('smaug.db.service_get_all_by_topic')
def test_get_all_by_topic(self, service_get_all_by_topic):
db_service = fake_service.fake_db_service()
service_get_all_by_topic.return_value = [db_service]
services = objects.ServiceList.get_all_by_topic(
self.context, 'foo', 'bar')
service_get_all_by_topic.assert_called_once_with(
self.context, 'foo', disabled='bar')
self.assertEqual(1, len(services))
TestService._compare(self, db_service, services[0])