Merge "Get db session from the context."

This commit is contained in:
Jenkins 2012-11-21 19:46:31 +00:00 committed by Gerrit Code Review
commit 18010f923b
6 changed files with 59 additions and 19 deletions

View File

@ -19,6 +19,7 @@ from heat.common import wsgi
from heat.openstack.common import cfg
from heat.openstack.common import importutils
from heat.common import utils as heat_utils
from heat.db import api as db_api
def generate_request_id():
@ -64,10 +65,17 @@ class RequestContext(object):
self.owner_is_tenant = owner_is_tenant
if overwrite or not hasattr(local.store, 'context'):
self.update_store()
self._session = None
def update_store(self):
local.store.context = self
@property
def session(self):
if self._session is None:
self._session = db_api.get_session()
return self._session
def to_dict(self):
return {'auth_token': self.auth_token,
'username': self.username,

View File

@ -48,6 +48,10 @@ def configure():
SQL_IDLE_TIMEOUT = cfg.CONF.sql_idle_timeout
def get_session():
return IMPL.get_session()
def raw_template_get(context, template_id):
return IMPL.raw_template_get(context, template_id)

View File

@ -22,17 +22,17 @@ from heat.db.sqlalchemy.session import get_session
from heat.engine import auth
def model_query(context, *args, **kwargs):
"""
:param session: if present, the session to use
"""
session = kwargs.get('session') or get_session()
def model_query(context, *args):
session = _session(context)
query = session.query(*args)
return query
def _session(context):
return (context and context.session) or get_session()
def raw_template_get(context, template_id):
result = model_query(context, models.RawTemplate).get(template_id)
@ -54,7 +54,7 @@ def raw_template_get_all(context):
def raw_template_create(context, values):
raw_template_ref = models.RawTemplate()
raw_template_ref.update(values)
raw_template_ref.save()
raw_template_ref.save(_session(context))
return raw_template_ref
@ -97,7 +97,7 @@ def resource_get_all(context):
def resource_create(context, values):
resource_ref = models.Resource()
resource_ref.update(values)
resource_ref.save()
resource_ref.save(_session(context))
return resource_ref
@ -152,7 +152,7 @@ def stack_get_all_by_tenant(context):
def stack_create(context, values):
stack_ref = models.Stack()
stack_ref.update(values)
stack_ref.save()
stack_ref.save(_session(context))
return stack_ref
@ -166,7 +166,7 @@ def stack_update(context, stack_id, values):
old_template_id = stack.raw_template_id
stack.update(values)
stack.save()
stack.save(_session(context))
# When the raw_template ID changes, we delete the old template
# after storing the new template ID
@ -201,13 +201,14 @@ def stack_delete(context, stack_id):
session.flush()
def user_creds_create(values):
def user_creds_create(context):
values = context.to_dict()
user_creds_ref = models.UserCreds()
user_creds_ref.update(values)
user_creds_ref.password = auth.encrypt(values['password'])
user_creds_ref.service_password = auth.encrypt(values['service_password'])
user_creds_ref.aws_creds = auth.encrypt(values['aws_creds'])
user_creds_ref.save()
user_creds_ref.save(_session(context))
return user_creds_ref
@ -255,7 +256,7 @@ def event_get_all_by_stack(context, stack_id):
def event_create(context, values):
event_ref = models.Event()
event_ref.update(values)
event_ref.save()
event_ref.save(_session(context))
return event_ref
@ -285,7 +286,7 @@ def watch_rule_get_all_by_stack(context, stack_id):
def watch_rule_create(context, values):
obj_ref = models.WatchRule()
obj_ref.update(values)
obj_ref.save()
obj_ref.save(_session(context))
return obj_ref
@ -297,7 +298,7 @@ def watch_rule_update(context, watch_id, values):
(watch_id, 'that does not exist'))
wr.update(values)
wr.save()
wr.save(_session(context))
def watch_rule_delete(context, watch_name):
@ -320,7 +321,7 @@ def watch_rule_delete(context, watch_name):
def watch_data_create(context, values):
obj_ref = models.WatchData()
obj_ref.update(values)
obj_ref.save()
obj_ref.save(_session(context))
return obj_ref

View File

@ -114,11 +114,11 @@ class Stack(object):
Store the stack in the database and return its ID
If self.id is set, we update the existing stack
'''
new_creds = db_api.user_creds_create(self.context.to_dict())
new_creds = db_api.user_creds_create(self.context)
s = {
'name': self.name,
'raw_template_id': self.t.store(),
'raw_template_id': self.t.store(self.context),
'parameters': self.parameters.user_parameters(),
'owner_id': owner and owner.id,
'user_creds_id': new_creds.id,

View File

@ -303,7 +303,7 @@ class Resource(object):
self.resource_id = inst
if self.id is not None:
try:
rs = db_api.resource_get(self.stack.context, self.id)
rs = db_api.resource_get(self.context, self.id)
rs.update_and_save({'nova_instance': self.resource_id})
except Exception as ex:
logger.warn('db error %s' % str(ex))

View File

@ -13,6 +13,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import functools
import webob
from heat.common import context
@ -32,6 +33,15 @@ from heat.openstack.common.rpc import service
logger = logging.getLogger(__name__)
def request_context(func):
@functools.wraps(func)
def wrapped(self, ctx, *args, **kwargs):
if ctx is not None and not isinstance(ctx, context.RequestContext):
ctx = context.RequestContext.from_dict(ctx.to_dict())
return func(self, ctx, *args, **kwargs)
return wrapped
class EngineService(service.Service):
"""
Manages the running instances from creation to destruction.
@ -90,6 +100,7 @@ class EngineService(service.Service):
self._periodic_watcher_task,
sid=s.id)
@request_context
def identify_stack(self, context, stack_name):
"""
The identify_stack method returns the full stack identifier for a
@ -123,6 +134,7 @@ class EngineService(service.Service):
return s
@request_context
def show_stack(self, context, stack_identity):
"""
The show_stack method returns the attributes of one stack.
@ -140,6 +152,7 @@ class EngineService(service.Service):
return {'stacks': [format_stack_detail(s) for s in stacks]}
@request_context
def list_stacks(self, context):
"""
The list_stacks method returns attributes of all stacks.
@ -153,6 +166,7 @@ class EngineService(service.Service):
return {'stacks': [format_stack_detail(s) for s in stacks]}
@request_context
def create_stack(self, context, stack_name, template, params, args):
"""
The create_stack method creates a new stack using the template
@ -194,6 +208,7 @@ class EngineService(service.Service):
return dict(stack.identifier())
@request_context
def update_stack(self, context, stack_identity, template, params, args):
"""
The update_stack method updates an existing stack based on the
@ -233,6 +248,7 @@ class EngineService(service.Service):
return dict(current_stack.identifier())
@request_context
def validate_template(self, context, template):
"""
The validate_template method uses the stack parser to check
@ -275,6 +291,7 @@ class EngineService(service.Service):
}
return result
@request_context
def get_template(self, context, stack_identity):
"""
Get the template.
@ -286,6 +303,7 @@ class EngineService(service.Service):
return s.raw_template.template
return None
@request_context
def delete_stack(self, context, stack_identity):
"""
The delete_stack method deletes a given stack.
@ -306,6 +324,7 @@ class EngineService(service.Service):
self.tg.add_thread(stack.delete)
return None
@request_context
def list_events(self, context, stack_identity):
"""
The list_events method lists all events associated with a given stack.
@ -321,6 +340,7 @@ class EngineService(service.Service):
return {'events': [api.format_event(context, e) for e in events]}
@request_context
def describe_stack_resource(self, context, stack_identity, resource_name):
s = self._get_stack(context, stack_identity)
@ -334,6 +354,7 @@ class EngineService(service.Service):
return api.format_stack_resource(stack[resource_name])
@request_context
def describe_stack_resources(self, context, stack_identity,
physical_resource_id, logical_resource_id):
if stack_identity is not None:
@ -360,6 +381,7 @@ class EngineService(service.Service):
for resource in stack if resource.id is not None and
name_match(resource)]
@request_context
def list_stack_resources(self, context, stack_identity):
s = self._get_stack(context, stack_identity)
@ -368,6 +390,7 @@ class EngineService(service.Service):
return [api.format_stack_resource(resource, detail=False)
for resource in stack if resource.id is not None]
@request_context
def metadata_update(self, context, stack_id, resource_name, metadata):
"""
Update the metadata for the given resource.
@ -418,6 +441,7 @@ class EngineService(service.Service):
rule = watchrule.WatchRule.load(stack_context, watch=wr)
rule.evaluate()
@request_context
def create_watch_data(self, context, watch_name, stats_data):
'''
This could be used by CloudWatch and WaitConditions
@ -428,6 +452,7 @@ class EngineService(service.Service):
logger.debug('new watch:%s data:%s' % (watch_name, str(stats_data)))
return stats_data
@request_context
def show_watch(self, context, watch_name):
'''
The show_watch method returns the attributes of one watch/alarm
@ -447,6 +472,7 @@ class EngineService(service.Service):
result = [api.format_watch(w) for w in wrs]
return result
@request_context
def show_watch_metric(self, context, namespace=None, metric_name=None):
'''
The show_watch method returns the datapoints for a metric
@ -471,6 +497,7 @@ class EngineService(service.Service):
result = [api.format_watch_data(w) for w in wds]
return result
@request_context
def set_watch_state(self, context, watch_name, state):
'''
Temporarily set the state of a given watch