Get db session from the context.

The aim is to use a single sqlalchemy session for an RPC request.

The context object passed to EngineAPI methods is actually an RpcContext
which contains the same data as the RequestContext. The @request_context
decorator turns this back into a RequestContext which can now have other
behaviours added to it.

RequestContext now has a lazy loaded session attribute.

Save calls on created entities need to be passed the shared session.

Change-Id: Ied4e66deaca205362b84fb698f75cc872886607d
This commit is contained in:
Steve Baker 2012-11-22 08:15:05 +13:00
parent b1787bd43e
commit 67f4f60815
6 changed files with 59 additions and 19 deletions

View File

@ -19,6 +19,7 @@ from heat.common import wsgi
from heat.openstack.common import cfg
from heat.openstack.common import importutils
from heat.common import utils as heat_utils
from heat.db import api as db_api
def generate_request_id():
@ -64,10 +65,17 @@ class RequestContext(object):
self.owner_is_tenant = owner_is_tenant
if overwrite or not hasattr(local.store, 'context'):
self.update_store()
self._session = None
def update_store(self):
local.store.context = self
@property
def session(self):
if self._session is None:
self._session = db_api.get_session()
return self._session
def to_dict(self):
return {'auth_token': self.auth_token,
'username': self.username,

View File

@ -48,6 +48,10 @@ def configure():
SQL_IDLE_TIMEOUT = cfg.CONF.sql_idle_timeout
def get_session():
return IMPL.get_session()
def raw_template_get(context, template_id):
return IMPL.raw_template_get(context, template_id)

View File

@ -22,17 +22,17 @@ from heat.db.sqlalchemy.session import get_session
from heat.engine import auth
def model_query(context, *args, **kwargs):
"""
:param session: if present, the session to use
"""
session = kwargs.get('session') or get_session()
def model_query(context, *args):
session = _session(context)
query = session.query(*args)
return query
def _session(context):
return (context and context.session) or get_session()
def raw_template_get(context, template_id):
result = model_query(context, models.RawTemplate).get(template_id)
@ -54,7 +54,7 @@ def raw_template_get_all(context):
def raw_template_create(context, values):
raw_template_ref = models.RawTemplate()
raw_template_ref.update(values)
raw_template_ref.save()
raw_template_ref.save(_session(context))
return raw_template_ref
@ -97,7 +97,7 @@ def resource_get_all(context):
def resource_create(context, values):
resource_ref = models.Resource()
resource_ref.update(values)
resource_ref.save()
resource_ref.save(_session(context))
return resource_ref
@ -147,7 +147,7 @@ def stack_get_all_by_tenant(context):
def stack_create(context, values):
stack_ref = models.Stack()
stack_ref.update(values)
stack_ref.save()
stack_ref.save(_session(context))
return stack_ref
@ -161,7 +161,7 @@ def stack_update(context, stack_id, values):
old_template_id = stack.raw_template_id
stack.update(values)
stack.save()
stack.save(_session(context))
# When the raw_template ID changes, we delete the old template
# after storing the new template ID
@ -196,13 +196,14 @@ def stack_delete(context, stack_id):
session.flush()
def user_creds_create(values):
def user_creds_create(context):
values = context.to_dict()
user_creds_ref = models.UserCreds()
user_creds_ref.update(values)
user_creds_ref.password = auth.encrypt(values['password'])
user_creds_ref.service_password = auth.encrypt(values['service_password'])
user_creds_ref.aws_creds = auth.encrypt(values['aws_creds'])
user_creds_ref.save()
user_creds_ref.save(_session(context))
return user_creds_ref
@ -250,7 +251,7 @@ def event_get_all_by_stack(context, stack_id):
def event_create(context, values):
event_ref = models.Event()
event_ref.update(values)
event_ref.save()
event_ref.save(_session(context))
return event_ref
@ -280,7 +281,7 @@ def watch_rule_get_all_by_stack(context, stack_id):
def watch_rule_create(context, values):
obj_ref = models.WatchRule()
obj_ref.update(values)
obj_ref.save()
obj_ref.save(_session(context))
return obj_ref
@ -292,7 +293,7 @@ def watch_rule_update(context, watch_id, values):
(watch_id, 'that does not exist'))
wr.update(values)
wr.save()
wr.save(_session(context))
def watch_rule_delete(context, watch_name):
@ -315,7 +316,7 @@ def watch_rule_delete(context, watch_name):
def watch_data_create(context, values):
obj_ref = models.WatchData()
obj_ref.update(values)
obj_ref.save()
obj_ref.save(_session(context))
return obj_ref

View File

@ -114,11 +114,11 @@ class Stack(object):
Store the stack in the database and return its ID
If self.id is set, we update the existing stack
'''
new_creds = db_api.user_creds_create(self.context.to_dict())
new_creds = db_api.user_creds_create(self.context)
s = {
'name': self.name,
'raw_template_id': self.t.store(),
'raw_template_id': self.t.store(self.context),
'parameters': self.parameters.user_parameters(),
'owner_id': owner and owner.id,
'user_creds_id': new_creds.id,

View File

@ -303,7 +303,7 @@ class Resource(object):
self.resource_id = inst
if self.id is not None:
try:
rs = db_api.resource_get(self.stack.context, self.id)
rs = db_api.resource_get(self.context, self.id)
rs.update_and_save({'nova_instance': self.resource_id})
except Exception as ex:
logger.warn('db error %s' % str(ex))

View File

@ -13,6 +13,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import functools
import webob
from heat.common import context
@ -32,6 +33,15 @@ from heat.openstack.common.rpc import service
logger = logging.getLogger(__name__)
def request_context(func):
@functools.wraps(func)
def wrapped(self, ctx, *args, **kwargs):
if ctx is not None and not isinstance(ctx, context.RequestContext):
ctx = context.RequestContext.from_dict(ctx.to_dict())
return func(self, ctx, *args, **kwargs)
return wrapped
class EngineService(service.Service):
"""
Manages the running instances from creation to destruction.
@ -96,6 +106,7 @@ class EngineService(service.Service):
context=stack_context,
sid=s.id)
@request_context
def identify_stack(self, context, stack_name):
"""
The identify_stack method returns the full stack identifier for a
@ -129,6 +140,7 @@ class EngineService(service.Service):
return s
@request_context
def show_stack(self, context, stack_identity):
"""
The show_stack method returns the attributes of one stack.
@ -146,6 +158,7 @@ class EngineService(service.Service):
return {'stacks': [format_stack_detail(s) for s in stacks]}
@request_context
def list_stacks(self, context):
"""
The list_stacks method returns attributes of all stacks.
@ -159,6 +172,7 @@ class EngineService(service.Service):
return {'stacks': [format_stack_detail(s) for s in stacks]}
@request_context
def create_stack(self, context, stack_name, template, params, args):
"""
The create_stack method creates a new stack using the template
@ -201,6 +215,7 @@ class EngineService(service.Service):
return dict(stack.identifier())
@request_context
def update_stack(self, context, stack_identity, template, params, args):
"""
The update_stack method updates an existing stack based on the
@ -240,6 +255,7 @@ class EngineService(service.Service):
return dict(current_stack.identifier())
@request_context
def validate_template(self, context, template):
"""
The validate_template method uses the stack parser to check
@ -282,6 +298,7 @@ class EngineService(service.Service):
}
return result
@request_context
def get_template(self, context, stack_identity):
"""
Get the template.
@ -293,6 +310,7 @@ class EngineService(service.Service):
return s.raw_template.template
return None
@request_context
def delete_stack(self, context, stack_identity):
"""
The delete_stack method deletes a given stack.
@ -313,6 +331,7 @@ class EngineService(service.Service):
self.tg.add_thread(stack.delete)
return None
@request_context
def list_events(self, context, stack_identity):
"""
The list_events method lists all events associated with a given stack.
@ -328,6 +347,7 @@ class EngineService(service.Service):
return {'events': [api.format_event(context, e) for e in events]}
@request_context
def describe_stack_resource(self, context, stack_identity, resource_name):
s = self._get_stack(context, stack_identity)
@ -341,6 +361,7 @@ class EngineService(service.Service):
return api.format_stack_resource(stack[resource_name])
@request_context
def describe_stack_resources(self, context, stack_identity,
physical_resource_id, logical_resource_id):
if stack_identity is not None:
@ -367,6 +388,7 @@ class EngineService(service.Service):
for resource in stack if resource.id is not None and
name_match(resource)]
@request_context
def list_stack_resources(self, context, stack_identity):
s = self._get_stack(context, stack_identity)
@ -375,6 +397,7 @@ class EngineService(service.Service):
return [api.format_stack_resource(resource, detail=False)
for resource in stack if resource.id is not None]
@request_context
def metadata_update(self, context, stack_id, resource_name, metadata):
"""
Update the metadata for the given resource.
@ -406,6 +429,7 @@ class EngineService(service.Service):
rule = watchrule.WatchRule.load(context, watch=wr)
rule.evaluate()
@request_context
def create_watch_data(self, context, watch_name, stats_data):
'''
This could be used by CloudWatch and WaitConditions
@ -416,6 +440,7 @@ class EngineService(service.Service):
logger.debug('new watch:%s data:%s' % (watch_name, str(stats_data)))
return stats_data
@request_context
def show_watch(self, context, watch_name):
'''
The show_watch method returns the attributes of one watch/alarm
@ -435,6 +460,7 @@ class EngineService(service.Service):
result = [api.format_watch(w) for w in wrs]
return result
@request_context
def show_watch_metric(self, context, namespace=None, metric_name=None):
'''
The show_watch method returns the datapoints for a metric
@ -459,6 +485,7 @@ class EngineService(service.Service):
result = [api.format_watch_data(w) for w in wds]
return result
@request_context
def set_watch_state(self, context, watch_name, state):
'''
Temporarily set the state of a given watch