diff --git a/ec2api/db/sqlalchemy/api.py b/ec2api/db/sqlalchemy/api.py index 984a89a0..d661fd87 100644 --- a/ec2api/db/sqlalchemy/api.py +++ b/ec2api/db/sqlalchemy/api.py @@ -134,7 +134,7 @@ def add_item_id(context, kind, os_id): item_ref.save() except db_exception.DBDuplicateEntry as ex: if (models.ITEMS_OS_ID_INDEX_NAME not in ex.columns and - 'os_id' not in ex.columns): + ex.columns != ['os_id']): raise item_ref = (model_query(context, models.Item). filter_by(os_id=os_id). @@ -248,13 +248,15 @@ def add_tags(context, tags): with session.begin(): for tag in tags: tag_ref = models.Tag(project_id=context.project_id, - item_id=tag['item_id']) - tag_ref.update(tag) + item_id=tag['item_id'], + key=tag['key'], + value=tag['value']) try: with session.begin(nested=True): tag_ref.save(session) except db_exception.DBDuplicateEntry as ex: - if 'PRIMARY' not in ex.columns: + if ('PRIMARY' not in ex.columns and + ex.columns != ['project_id', 'item_id', 'key']): raise (get_query.params(tag_item_id=tag['item_id'], tag_key=tag['key']). diff --git a/ec2api/tests/test_db_api.py b/ec2api/tests/test_db_api.py index ad04f2ef..8f83ad84 100644 --- a/ec2api/tests/test_db_api.py +++ b/ec2api/tests/test_db_api.py @@ -14,6 +14,7 @@ from oslo.config import cfg from oslotest import base as test_base +from sqlalchemy import event from sqlalchemy.orm import exc as orm_exception from ec2api.api import validator @@ -38,6 +39,21 @@ class DbApiTestCase(test_base.BaseTestCase): conf.set_override('sqlite_synchronous', False, group='database') engine = session.get_engine() + + # NOTE(ft): enable savepoints in sqlite. See SAVEPOINT support + # section in sqlalchemy.dialects.sqlite.base.py + + @event.listens_for(engine, "connect") + def do_connect(dbapi_connection, connection_record): + # disable pysqlite's emitting of the BEGIN statement entirely. + # also stops it from emitting COMMIT before any DDL. + dbapi_connection.isolation_level = None + + @event.listens_for(engine, "begin") + def do_begin(conn): + # emit our own BEGIN + conn.execute("BEGIN") + conn = engine.connect() migration.db_sync() cls.DB_SCHEMA = "".join(line @@ -266,3 +282,162 @@ class DbApiTestCase(test_base.BaseTestCase): self.assertEqual(0, len(items)) items = db_api.get_public_items(self.context, 'fake0', []) self.assertEqual(0, len(items)) + + def test_add_tags(self): + item1_id = fakes.random_ec2_id('fake') + item2_id = fakes.random_ec2_id('fake') + item3_id = fakes.random_ec2_id('fake') + tag1_01 = {'item_id': item1_id, + 'key': 'key1', + 'value': None} + tag1_1 = {'item_id': item1_id, + 'key': 'key1', + 'value': 'val'} + tag1_2 = {'item_id': item1_id, + 'key': 'key2', + 'value': 'val'} + tag1_3 = {'item_id': item1_id, + 'key': 'key3', + 'value': 'val'} + tag2_1 = {'item_id': item2_id, + 'key': 'key1', + 'value': None} + tag2_2 = {'item_id': item2_id, + 'key': 'key2', + 'value': 'val'} + tag3_1 = {'item_id': item3_id, + 'key': 'key1', + 'value': 'val'} + tag3_3 = {'item_id': item3_id, + 'key': 'key3', + 'value': 'val'} + db_api.add_tags(self.context, [tag1_01, tag2_1, + tag1_2, tag2_2]) + db_api.add_tags(self.context, [tag1_1, tag3_1, + tag1_3, tag3_3]) + tags = db_api.get_tags(self.context) + self.assertThat(tags, + matchers.ListMatches([tag1_1, tag1_2, tag1_3, + tag2_1, tag2_2, + tag3_1, tag3_3], + orderless_lists=True)) + + def test_add_tags_isolation(self): + item_id = fakes.random_ec2_id('fake') + tag1 = {'item_id': item_id, + 'key': 'key1', + 'value': 'val1'} + tag2 = {'item_id': item_id, + 'key': 'key2', + 'value': 'val2'} + db_api.add_tags(self.context, [tag1, tag2]) + db_api.add_tags(self.other_context, [{'item_id': item_id, + 'key': 'key1', + 'value': 'val1_1'}, + {'item_id': item_id, + 'key': 'key3', + 'value': 'val3'}]) + tags = db_api.get_tags(self.context) + self.assertThat(tags, matchers.ListMatches([tag1, tag2], + orderless_lists=True)) + + def test_get_tags(self): + item1_id = fakes.random_ec2_id('fake') + item2_id = fakes.random_ec2_id('fake') + item3_id = fakes.random_ec2_id('fake1') + tag1 = {'item_id': item1_id, + 'key': 'key1', + 'value': 'val1'} + tag2 = {'item_id': item2_id, + 'key': 'key2', + 'value': 'val2'} + tag3 = {'item_id': item3_id, + 'key': 'key3', + 'value': 'val3'} + db_api.add_tags(self.context, [tag1, tag2, tag3]) + + self.assertThat(db_api.get_tags(self.context), + matchers.ListMatches([tag1, tag2, tag3], + orderless_lists=True)) + self.assertThat(db_api.get_tags(self.context, ('fake',)), + matchers.ListMatches([tag1, tag2], + orderless_lists=True)) + self.assertThat(db_api.get_tags(self.context, ('fake',), + [item1_id, item2_id]), + matchers.ListMatches([tag1, tag2], + orderless_lists=True)) + self.assertThat(db_api.get_tags(self.context, ('fake',), (item1_id,)), + matchers.ListMatches([tag1], + orderless_lists=True)) + self.assertThat(db_api.get_tags(self.context, ('fake',), (item3_id,)), + matchers.ListMatches([])) + self.assertThat(db_api.get_tags(self.context, + item_ids=(item1_id, item3_id)), + matchers.ListMatches([tag1, tag3], + orderless_lists=True)) + self.assertThat(db_api.get_tags(self.context, ('fake', 'fake1'), + (item2_id, item3_id)), + matchers.ListMatches([tag2, tag3], + orderless_lists=True)) + + def test_delete_tags(self): + item1_id = fakes.random_ec2_id('fake') + item2_id = fakes.random_ec2_id('fake') + item3_id = fakes.random_ec2_id('fake1') + tag1_1 = {'item_id': item1_id, + 'key': 'key1', + 'value': 'val_a'} + tag1_2 = {'item_id': item1_id, + 'key': 'key2', + 'value': 'val_b'} + tag2_1 = {'item_id': item2_id, + 'key': 'key1', + 'value': 'val_c'} + tag2_2 = {'item_id': item2_id, + 'key': 'key2', + 'value': 'val_a'} + tag3_1 = {'item_id': item3_id, + 'key': 'key1', + 'value': 'val_b'} + tag3_2 = {'item_id': item3_id, + 'key': 'key2', + 'value': 'val_d'} + db_api.add_tags(self.context, [tag1_1, tag2_1, tag3_1, + tag1_2, tag2_2, tag3_2]) + + def do_check(*tag_list): + self.assertThat(db_api.get_tags(self.context), + matchers.ListMatches(tag_list, + orderless_lists=True)) + db_api.add_tags(self.context, [tag1_1, tag2_1, tag3_1, + tag1_2, tag2_2, tag3_2]) + + db_api.delete_tags(self.context, []) + do_check(tag1_1, tag1_2, tag2_1, tag2_2, tag3_1, tag3_2) + + db_api.delete_tags(self.context, [item1_id]) + do_check(tag2_1, tag2_2, tag3_1, tag3_2) + + db_api.delete_tags(self.context, [item1_id, item3_id]) + do_check(tag2_1, tag2_2) + + db_api.delete_tags(self.context, [item1_id, item2_id, item3_id], + [{'key': 'key1'}, + {'value': 'val_d'}, + {'key': 'key2', + 'value': 'val_b'}]) + do_check(tag2_2) + + def delete_tags_isolation(self): + item_id = fakes.random_ec2_id('fake') + tag1 = {'item_id': item_id, + 'key': 'key', + 'value': 'val1'} + db_api.add_tags(self.context, tag1) + tag2 = {'item_id': item_id, + 'key': 'key', + 'value': 'val2'} + db_api.add_tags(self.other_context, tag2) + db_api.delete_tags(self.context, item_id) + self.assertThat(db_api.get_tags(self.other_context), + matchers.ListMatches([tag2]))