Switch to explicit transaction management

Explicit is way better than implicit in this case. While managing
transactions requires an extra line/indent here and there, we should
strive to not depending from framework to do it for us.

Change-Id: I610b8db8fa65cc875bc848c51ccff675df38222b
This commit is contained in:
Yuriy Taraday 2016-04-28 11:03:26 +03:00
parent 0d866a17ab
commit c62ea00d30
4 changed files with 32 additions and 21 deletions

View File

@ -10,6 +10,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import functools
import itertools
import flask
@ -37,6 +38,15 @@ component_fields = {
}
def with_transaction(f):
@functools.wraps(f)
def inner(*args, **kwargs):
with db.db.session.begin():
return f(*args, **kwargs)
return inner
@api.resource('/components')
class ComponentsCollection(flask_restful.Resource):
method_decorators = [flask_restful.marshal_with(component_fields)]
@ -44,6 +54,7 @@ class ComponentsCollection(flask_restful.Resource):
def get(self):
return db.Component.query.all()
@with_transaction
def post(self):
component = db.Component(name=flask.request.json['name'])
component.resource_definitions = []
@ -52,7 +63,6 @@ class ComponentsCollection(flask_restful.Resource):
content=resdef_data.get('content'))
component.resource_definitions.append(resdef)
db.db.session.add(component)
db.db.session.commit()
return component, 201
@ -63,10 +73,10 @@ class Component(flask_restful.Resource):
def get(self, component_id):
return db.Component.query.get_or_404(component_id)
@with_transaction
def delete(self, component_id):
component = db.Component.query.get_or_404(component_id)
db.db.session.delete(component)
db.db.session.commit()
return None, 204
environment_fields = {
@ -83,6 +93,7 @@ class EnvironmentsCollection(flask_restful.Resource):
def get(self):
return db.Environment.query.all()
@with_transaction
def post(self):
component_ids = flask.request.json['components']
# TODO(yorik-sar): verify that resource names do not clash
@ -100,7 +111,6 @@ class EnvironmentsCollection(flask_restful.Resource):
if 'id' in flask.request.json:
environment.id = flask.request.json['id']
db.db.session.add(environment)
db.db.session.commit()
return environment, 201
@ -111,10 +121,10 @@ class Environment(flask_restful.Resource):
def get(self, environment_id):
return db.Environment.query.get_or_404(environment_id)
@with_transaction
def delete(self, environment_id):
environment = db.Environment.query.get_or_404(environment_id)
db.db.session.delete(environment)
db.db.session.commit()
return None, 204
@ -148,6 +158,7 @@ def get_environment_level_value(environment, levels):
'/environments/<int:environment_id>' +
'/<levels:levels>resources/<id_or_name:resource_id_or_name>/values')
class ResourceValues(flask_restful.Resource):
@with_transaction
def put(self, environment_id, levels, resource_id_or_name):
environment = db.Environment.query.get_or_404(environment_id)
level_value = get_environment_level_value(environment, levels)
@ -168,9 +179,9 @@ class ResourceValues(flask_restful.Resource):
level_value=level_value,
)
esv.values = flask.request.json
db.db.session.commit()
return None, 204
@with_transaction
def get(self, environment_id, resource_id_or_name, levels):
environment = db.Environment.query.get_or_404(environment_id)
level_values = list(iter_environment_level_values(environment, levels))
@ -220,6 +231,7 @@ class ResourceValues(flask_restful.Resource):
'/environments/<int:environment_id>' +
'/<levels:levels>resources/<id_or_name:resource_id_or_name>/overrides')
class ResourceOverrides(flask_restful.Resource):
@with_transaction
def put(self, environment_id, levels, resource_id_or_name):
environment = db.Environment.query.get_or_404(environment_id)
level_value = get_environment_level_value(environment, levels)
@ -240,9 +252,9 @@ class ResourceOverrides(flask_restful.Resource):
level_value=level_value,
)
esv.overrides = flask.request.json
db.db.session.commit()
return None, 204
@with_transaction
def get(self, environment_id, resource_id_or_name, levels):
environment = db.Environment.query.get_or_404(environment_id)
level_value = get_environment_level_value(environment, levels)

View File

@ -26,7 +26,7 @@ try:
except ImportError:
pass # in 2.x reload is builtin
db = flask_sqlalchemy.SQLAlchemy()
db = flask_sqlalchemy.SQLAlchemy(session_options={'autocommit': True})
pk_type = db.Integer
pk = functools.partial(db.Column, pk_type, primary_key=True)

View File

@ -51,7 +51,7 @@ class TestApp(base.TestCase):
self.client = Client(self.app)
def _fixture(self):
with self.app.app_context():
with self.app.app_context(), db.db.session.begin():
component = db.Component(
id=7,
name='component1',
@ -70,7 +70,6 @@ class TestApp(base.TestCase):
hierarchy_levels[1].parent = hierarchy_levels[0]
environment.hierarchy_levels = hierarchy_levels
db.db.session.add(environment)
db.db.session.commit()
@property
def _component_json(self):
@ -251,7 +250,7 @@ class TestApp(base.TestCase):
def test_get_environment_level_value_root(self):
self._fixture()
with self.app.app_context():
with self.app.app_context(), db.db.session.begin():
level_value = app.get_environment_level_value(
db.Environment(id=9),
[],
@ -260,7 +259,7 @@ class TestApp(base.TestCase):
def test_get_environment_level_value_deep(self):
self._fixture()
with self.app.app_context():
with self.app.app_context(), db.db.session.begin():
level_value = app.get_environment_level_value(
db.Environment(id=9),
[('lvl1', 'val1'), ('lvl2', 'val2')],
@ -275,7 +274,7 @@ class TestApp(base.TestCase):
def test_get_environment_level_value_bad_level(self):
self._fixture()
with self.app.app_context():
with self.app.app_context(), db.db.session.begin():
exc = self.assertRaises(
exceptions.BadRequest,
app.get_environment_level_value,

View File

@ -48,13 +48,13 @@ class TestDB(_DBTestCase):
pass
def test_get_or_create_get(self):
with self.app.app_context():
with self.app.app_context(), db.db.session.begin():
db.db.session.add(db.Component(name="nsname"))
res = db.get_or_create(db.Component, name="nsname")
self.assertEqual(res.name, "nsname")
def test_get_or_create_create(self):
with self.app.app_context():
with self.app.app_context(), db.db.session.begin():
res = db.get_or_create(db.Component, name="nsname")
self.assertIsNotNone(res.id)
self.assertEqual(res.name, "nsname")
@ -100,11 +100,11 @@ class TestDBPrefixed(base.PrefixedTestCaseMixin, TestDB):
class TestEnvironmentHierarchyLevel(_DBTestCase):
def setUp(self):
super(TestEnvironmentHierarchyLevel, self).setUp()
with self.app.app_context():
with self.app.app_context(), db.db.session.begin():
session = db.db.session
environment = db.Environment()
session.add(environment)
session.commit()
session.flush()
self.environment_id = environment.id
def _create_levels(self, num):
@ -118,14 +118,14 @@ class TestEnvironmentHierarchyLevel(_DBTestCase):
)
session.add(lvl)
last_lvl = lvl
session.commit()
session.flush()
def _test_get_for_environment(self, num, expected):
with self.app.app_context():
with self.app.app_context(), db.db.session.begin():
self._create_levels(num)
res = db.EnvironmentHierarchyLevel.get_for_environment(
db.Environment(id=self.environment_id))
level_names = [level.name for level in res]
env = db.Environment(id=self.environment_id)
res = db.EnvironmentHierarchyLevel.get_for_environment(env)
level_names = [level.name for level in res]
self.assertEqual(level_names, expected)
def test_get_for_environment_empty(self):