Get db session from the context.

The aim is to use a single sqlalchemy session for an RPC request.

The context object passed to EngineAPI methods is actually an RpcContext
which contains the same data as the RequestContext. The @request_context
decorator turns this back into a RequestContext which can now have other
behaviours added to it.

RequestContext now has a lazy loaded session attribute.

Save calls on created entities need to be passed the shared session.

Change-Id: Ied4e66deaca205362b84fb698f75cc872886607d
This commit is contained in:
Steve Baker 2012-11-22 08:15:05 +13:00
parent b1787bd43e
commit 67f4f60815
6 changed files with 59 additions and 19 deletions

View File

@ -19,6 +19,7 @@ from heat.common import wsgi
from heat.openstack.common import cfg from heat.openstack.common import cfg
from heat.openstack.common import importutils from heat.openstack.common import importutils
from heat.common import utils as heat_utils from heat.common import utils as heat_utils
from heat.db import api as db_api
def generate_request_id(): def generate_request_id():
@ -64,10 +65,17 @@ class RequestContext(object):
self.owner_is_tenant = owner_is_tenant self.owner_is_tenant = owner_is_tenant
if overwrite or not hasattr(local.store, 'context'): if overwrite or not hasattr(local.store, 'context'):
self.update_store() self.update_store()
self._session = None
def update_store(self): def update_store(self):
local.store.context = self local.store.context = self
@property
def session(self):
if self._session is None:
self._session = db_api.get_session()
return self._session
def to_dict(self): def to_dict(self):
return {'auth_token': self.auth_token, return {'auth_token': self.auth_token,
'username': self.username, 'username': self.username,

View File

@ -48,6 +48,10 @@ def configure():
SQL_IDLE_TIMEOUT = cfg.CONF.sql_idle_timeout SQL_IDLE_TIMEOUT = cfg.CONF.sql_idle_timeout
def get_session():
return IMPL.get_session()
def raw_template_get(context, template_id): def raw_template_get(context, template_id):
return IMPL.raw_template_get(context, template_id) return IMPL.raw_template_get(context, template_id)

View File

@ -22,17 +22,17 @@ from heat.db.sqlalchemy.session import get_session
from heat.engine import auth from heat.engine import auth
def model_query(context, *args, **kwargs): def model_query(context, *args):
""" session = _session(context)
:param session: if present, the session to use
"""
session = kwargs.get('session') or get_session()
query = session.query(*args) query = session.query(*args)
return query return query
def _session(context):
return (context and context.session) or get_session()
def raw_template_get(context, template_id): def raw_template_get(context, template_id):
result = model_query(context, models.RawTemplate).get(template_id) result = model_query(context, models.RawTemplate).get(template_id)
@ -54,7 +54,7 @@ def raw_template_get_all(context):
def raw_template_create(context, values): def raw_template_create(context, values):
raw_template_ref = models.RawTemplate() raw_template_ref = models.RawTemplate()
raw_template_ref.update(values) raw_template_ref.update(values)
raw_template_ref.save() raw_template_ref.save(_session(context))
return raw_template_ref return raw_template_ref
@ -97,7 +97,7 @@ def resource_get_all(context):
def resource_create(context, values): def resource_create(context, values):
resource_ref = models.Resource() resource_ref = models.Resource()
resource_ref.update(values) resource_ref.update(values)
resource_ref.save() resource_ref.save(_session(context))
return resource_ref return resource_ref
@ -147,7 +147,7 @@ def stack_get_all_by_tenant(context):
def stack_create(context, values): def stack_create(context, values):
stack_ref = models.Stack() stack_ref = models.Stack()
stack_ref.update(values) stack_ref.update(values)
stack_ref.save() stack_ref.save(_session(context))
return stack_ref return stack_ref
@ -161,7 +161,7 @@ def stack_update(context, stack_id, values):
old_template_id = stack.raw_template_id old_template_id = stack.raw_template_id
stack.update(values) stack.update(values)
stack.save() stack.save(_session(context))
# When the raw_template ID changes, we delete the old template # When the raw_template ID changes, we delete the old template
# after storing the new template ID # after storing the new template ID
@ -196,13 +196,14 @@ def stack_delete(context, stack_id):
session.flush() session.flush()
def user_creds_create(values): def user_creds_create(context):
values = context.to_dict()
user_creds_ref = models.UserCreds() user_creds_ref = models.UserCreds()
user_creds_ref.update(values) user_creds_ref.update(values)
user_creds_ref.password = auth.encrypt(values['password']) user_creds_ref.password = auth.encrypt(values['password'])
user_creds_ref.service_password = auth.encrypt(values['service_password']) user_creds_ref.service_password = auth.encrypt(values['service_password'])
user_creds_ref.aws_creds = auth.encrypt(values['aws_creds']) user_creds_ref.aws_creds = auth.encrypt(values['aws_creds'])
user_creds_ref.save() user_creds_ref.save(_session(context))
return user_creds_ref return user_creds_ref
@ -250,7 +251,7 @@ def event_get_all_by_stack(context, stack_id):
def event_create(context, values): def event_create(context, values):
event_ref = models.Event() event_ref = models.Event()
event_ref.update(values) event_ref.update(values)
event_ref.save() event_ref.save(_session(context))
return event_ref return event_ref
@ -280,7 +281,7 @@ def watch_rule_get_all_by_stack(context, stack_id):
def watch_rule_create(context, values): def watch_rule_create(context, values):
obj_ref = models.WatchRule() obj_ref = models.WatchRule()
obj_ref.update(values) obj_ref.update(values)
obj_ref.save() obj_ref.save(_session(context))
return obj_ref return obj_ref
@ -292,7 +293,7 @@ def watch_rule_update(context, watch_id, values):
(watch_id, 'that does not exist')) (watch_id, 'that does not exist'))
wr.update(values) wr.update(values)
wr.save() wr.save(_session(context))
def watch_rule_delete(context, watch_name): def watch_rule_delete(context, watch_name):
@ -315,7 +316,7 @@ def watch_rule_delete(context, watch_name):
def watch_data_create(context, values): def watch_data_create(context, values):
obj_ref = models.WatchData() obj_ref = models.WatchData()
obj_ref.update(values) obj_ref.update(values)
obj_ref.save() obj_ref.save(_session(context))
return obj_ref return obj_ref

View File

@ -114,11 +114,11 @@ class Stack(object):
Store the stack in the database and return its ID Store the stack in the database and return its ID
If self.id is set, we update the existing stack If self.id is set, we update the existing stack
''' '''
new_creds = db_api.user_creds_create(self.context.to_dict()) new_creds = db_api.user_creds_create(self.context)
s = { s = {
'name': self.name, 'name': self.name,
'raw_template_id': self.t.store(), 'raw_template_id': self.t.store(self.context),
'parameters': self.parameters.user_parameters(), 'parameters': self.parameters.user_parameters(),
'owner_id': owner and owner.id, 'owner_id': owner and owner.id,
'user_creds_id': new_creds.id, 'user_creds_id': new_creds.id,

View File

@ -303,7 +303,7 @@ class Resource(object):
self.resource_id = inst self.resource_id = inst
if self.id is not None: if self.id is not None:
try: try:
rs = db_api.resource_get(self.stack.context, self.id) rs = db_api.resource_get(self.context, self.id)
rs.update_and_save({'nova_instance': self.resource_id}) rs.update_and_save({'nova_instance': self.resource_id})
except Exception as ex: except Exception as ex:
logger.warn('db error %s' % str(ex)) logger.warn('db error %s' % str(ex))

View File

@ -13,6 +13,7 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import functools
import webob import webob
from heat.common import context from heat.common import context
@ -32,6 +33,15 @@ from heat.openstack.common.rpc import service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def request_context(func):
@functools.wraps(func)
def wrapped(self, ctx, *args, **kwargs):
if ctx is not None and not isinstance(ctx, context.RequestContext):
ctx = context.RequestContext.from_dict(ctx.to_dict())
return func(self, ctx, *args, **kwargs)
return wrapped
class EngineService(service.Service): class EngineService(service.Service):
""" """
Manages the running instances from creation to destruction. Manages the running instances from creation to destruction.
@ -96,6 +106,7 @@ class EngineService(service.Service):
context=stack_context, context=stack_context,
sid=s.id) sid=s.id)
@request_context
def identify_stack(self, context, stack_name): def identify_stack(self, context, stack_name):
""" """
The identify_stack method returns the full stack identifier for a The identify_stack method returns the full stack identifier for a
@ -129,6 +140,7 @@ class EngineService(service.Service):
return s return s
@request_context
def show_stack(self, context, stack_identity): def show_stack(self, context, stack_identity):
""" """
The show_stack method returns the attributes of one stack. The show_stack method returns the attributes of one stack.
@ -146,6 +158,7 @@ class EngineService(service.Service):
return {'stacks': [format_stack_detail(s) for s in stacks]} return {'stacks': [format_stack_detail(s) for s in stacks]}
@request_context
def list_stacks(self, context): def list_stacks(self, context):
""" """
The list_stacks method returns attributes of all stacks. The list_stacks method returns attributes of all stacks.
@ -159,6 +172,7 @@ class EngineService(service.Service):
return {'stacks': [format_stack_detail(s) for s in stacks]} return {'stacks': [format_stack_detail(s) for s in stacks]}
@request_context
def create_stack(self, context, stack_name, template, params, args): def create_stack(self, context, stack_name, template, params, args):
""" """
The create_stack method creates a new stack using the template The create_stack method creates a new stack using the template
@ -201,6 +215,7 @@ class EngineService(service.Service):
return dict(stack.identifier()) return dict(stack.identifier())
@request_context
def update_stack(self, context, stack_identity, template, params, args): def update_stack(self, context, stack_identity, template, params, args):
""" """
The update_stack method updates an existing stack based on the The update_stack method updates an existing stack based on the
@ -240,6 +255,7 @@ class EngineService(service.Service):
return dict(current_stack.identifier()) return dict(current_stack.identifier())
@request_context
def validate_template(self, context, template): def validate_template(self, context, template):
""" """
The validate_template method uses the stack parser to check The validate_template method uses the stack parser to check
@ -282,6 +298,7 @@ class EngineService(service.Service):
} }
return result return result
@request_context
def get_template(self, context, stack_identity): def get_template(self, context, stack_identity):
""" """
Get the template. Get the template.
@ -293,6 +310,7 @@ class EngineService(service.Service):
return s.raw_template.template return s.raw_template.template
return None return None
@request_context
def delete_stack(self, context, stack_identity): def delete_stack(self, context, stack_identity):
""" """
The delete_stack method deletes a given stack. The delete_stack method deletes a given stack.
@ -313,6 +331,7 @@ class EngineService(service.Service):
self.tg.add_thread(stack.delete) self.tg.add_thread(stack.delete)
return None return None
@request_context
def list_events(self, context, stack_identity): def list_events(self, context, stack_identity):
""" """
The list_events method lists all events associated with a given stack. The list_events method lists all events associated with a given stack.
@ -328,6 +347,7 @@ class EngineService(service.Service):
return {'events': [api.format_event(context, e) for e in events]} return {'events': [api.format_event(context, e) for e in events]}
@request_context
def describe_stack_resource(self, context, stack_identity, resource_name): def describe_stack_resource(self, context, stack_identity, resource_name):
s = self._get_stack(context, stack_identity) s = self._get_stack(context, stack_identity)
@ -341,6 +361,7 @@ class EngineService(service.Service):
return api.format_stack_resource(stack[resource_name]) return api.format_stack_resource(stack[resource_name])
@request_context
def describe_stack_resources(self, context, stack_identity, def describe_stack_resources(self, context, stack_identity,
physical_resource_id, logical_resource_id): physical_resource_id, logical_resource_id):
if stack_identity is not None: if stack_identity is not None:
@ -367,6 +388,7 @@ class EngineService(service.Service):
for resource in stack if resource.id is not None and for resource in stack if resource.id is not None and
name_match(resource)] name_match(resource)]
@request_context
def list_stack_resources(self, context, stack_identity): def list_stack_resources(self, context, stack_identity):
s = self._get_stack(context, stack_identity) s = self._get_stack(context, stack_identity)
@ -375,6 +397,7 @@ class EngineService(service.Service):
return [api.format_stack_resource(resource, detail=False) return [api.format_stack_resource(resource, detail=False)
for resource in stack if resource.id is not None] for resource in stack if resource.id is not None]
@request_context
def metadata_update(self, context, stack_id, resource_name, metadata): def metadata_update(self, context, stack_id, resource_name, metadata):
""" """
Update the metadata for the given resource. Update the metadata for the given resource.
@ -406,6 +429,7 @@ class EngineService(service.Service):
rule = watchrule.WatchRule.load(context, watch=wr) rule = watchrule.WatchRule.load(context, watch=wr)
rule.evaluate() rule.evaluate()
@request_context
def create_watch_data(self, context, watch_name, stats_data): def create_watch_data(self, context, watch_name, stats_data):
''' '''
This could be used by CloudWatch and WaitConditions This could be used by CloudWatch and WaitConditions
@ -416,6 +440,7 @@ class EngineService(service.Service):
logger.debug('new watch:%s data:%s' % (watch_name, str(stats_data))) logger.debug('new watch:%s data:%s' % (watch_name, str(stats_data)))
return stats_data return stats_data
@request_context
def show_watch(self, context, watch_name): def show_watch(self, context, watch_name):
''' '''
The show_watch method returns the attributes of one watch/alarm The show_watch method returns the attributes of one watch/alarm
@ -435,6 +460,7 @@ class EngineService(service.Service):
result = [api.format_watch(w) for w in wrs] result = [api.format_watch(w) for w in wrs]
return result return result
@request_context
def show_watch_metric(self, context, namespace=None, metric_name=None): def show_watch_metric(self, context, namespace=None, metric_name=None):
''' '''
The show_watch method returns the datapoints for a metric The show_watch method returns the datapoints for a metric
@ -459,6 +485,7 @@ class EngineService(service.Service):
result = [api.format_watch_data(w) for w in wds] result = [api.format_watch_data(w) for w in wds]
return result return result
@request_context
def set_watch_state(self, context, watch_name, state): def set_watch_state(self, context, watch_name, state):
''' '''
Temporarily set the state of a given watch Temporarily set the state of a given watch