Merge "Split endpoint records in SQL by interface"

This commit is contained in:
Jenkins 2012-12-18 18:01:40 +00:00 committed by Gerrit Code Review
commit ac2d92ca2e
9 changed files with 487 additions and 106 deletions

View File

@ -36,12 +36,14 @@ class Service(sql.ModelBase, sql.DictBase):
class Endpoint(sql.ModelBase, sql.DictBase):
__tablename__ = 'endpoint'
attributes = ['id', 'region', 'service_id']
attributes = ['id', 'interface', 'region', 'service_id', 'url']
id = sql.Column(sql.String(64), primary_key=True)
interface = sql.Column(sql.String(8), primary_key=True)
region = sql.Column('region', sql.String(255))
service_id = sql.Column(sql.String(64),
sql.ForeignKey('service.id'),
nullable=False)
url = sql.Column(sql.Text())
extra = sql.Column(sql.JsonBlob())
@ -88,7 +90,9 @@ class Catalog(sql.Base, catalog.Driver):
old_dict = ref.to_dict()
old_dict.update(service_ref)
new_service = Service.from_dict(old_dict)
ref.type = new_service.type
for attr in Service.attributes:
if attr != 'id':
setattr(ref, attr, getattr(new_service, attr))
ref.extra = new_service.extra
session.flush()
return ref.to_dict()
@ -132,8 +136,9 @@ class Catalog(sql.Base, catalog.Driver):
old_dict = ref.to_dict()
old_dict.update(endpoint_ref)
new_endpoint = Endpoint.from_dict(old_dict)
ref.service_id = new_endpoint.service_id
ref.region = new_endpoint.region
for attr in Endpoint.attributes:
if attr != 'id':
setattr(ref, attr, getattr(new_endpoint, attr))
ref.extra = new_endpoint.extra
session.flush()
return ref.to_dict()
@ -142,25 +147,28 @@ class Catalog(sql.Base, catalog.Driver):
d = dict(CONF.iteritems())
d.update({'tenant_id': tenant_id,
'user_id': user_id})
catalog = {}
services = {}
for endpoint in self.list_endpoints():
# look up the service
services.setdefault(
endpoint['service_id'],
self.get_service(endpoint['service_id']))
service = services[endpoint['service_id']]
endpoints = self.list_endpoints()
for ep in endpoints:
service = self.get_service(ep['service_id'])
srv_type = service['type']
srv_name = service['name']
region = ep['region']
# add the endpoint to the catalog if it's not already there
catalog.setdefault(endpoint['region'], {})
catalog[endpoint['region']].setdefault(
service['type'], {
'id': endpoint['id'],
'name': service['name'],
'publicURL': '', # this may be overridden, but must exist
})
if region not in catalog:
catalog[region] = {}
catalog[region][srv_type] = {}
srv_type = catalog[region][srv_type]
srv_type['id'] = ep['id']
srv_type['name'] = srv_name
srv_type['publicURL'] = core.format_url(ep.get('publicurl', ''), d)
srv_type['internalURL'] = core.format_url(ep.get('internalurl'), d)
srv_type['adminURL'] = core.format_url(ep.get('adminurl'), d)
# add the interface's url
url = core.format_url(endpoint.get('url'), d)
interface_url = '%sURL' % endpoint['interface']
catalog[endpoint['region']][service['type']][interface_url] = url
return catalog

View File

@ -20,6 +20,13 @@ import uuid
from keystone.catalog import core
from keystone.common import controller
from keystone.common import wsgi
from keystone import exception
from keystone import identity
from keystone import policy
from keystone import token
INTERFACES = ['public', 'internal', 'admin']
class Service(controller.V2Controller):
@ -50,22 +57,62 @@ class Service(controller.V2Controller):
class Endpoint(controller.V2Controller):
def get_endpoints(self, context):
"""Merge matching v3 endpoint refs into legacy refs."""
self.assert_admin(context)
endpoint_list = self.catalog_api.list_endpoints(context)
return {'endpoints': endpoint_list}
legacy_endpoints = {}
for endpoint in self.catalog_api.list_endpoints(context):
if not endpoint['legacy_endpoint_id']:
# endpoints created in v3 should not appear on the v2 API
continue
# is this is a legacy endpoint we haven't indexed yet?
if endpoint['legacy_endpoint_id'] not in legacy_endpoints:
legacy_ep = endpoint.copy()
legacy_ep['id'] = legacy_ep.pop('legacy_endpoint_id')
legacy_ep.pop('interface')
legacy_ep.pop('url')
legacy_endpoints[endpoint['legacy_endpoint_id']] = legacy_ep
else:
legacy_ep = legacy_endpoints[endpoint['legacy_endpoint_id']]
# add the legacy endpoint with an interface url
legacy_ep['%surl' % endpoint['interface']] = endpoint['url']
return {'endpoints': legacy_endpoints.values()}
def create_endpoint(self, context, endpoint):
"""Create three v3 endpoint refs based on a legacy ref."""
self.assert_admin(context)
endpoint_id = uuid.uuid4().hex
endpoint_ref = endpoint.copy()
endpoint_ref['id'] = endpoint_id
new_endpoint_ref = self.catalog_api.create_endpoint(
context, endpoint_id, endpoint_ref)
return {'endpoint': new_endpoint_ref}
legacy_endpoint_ref = endpoint.copy()
urls = dict((i, endpoint.pop('%surl' % i)) for i in INTERFACES)
legacy_endpoint_id = uuid.uuid4().hex
for interface, url in urls.iteritems():
endpoint_ref = endpoint.copy()
endpoint_ref['id'] = uuid.uuid4().hex
endpoint_ref['legacy_endpoint_id'] = legacy_endpoint_id
endpoint_ref['interface'] = interface
endpoint_ref['url'] = url
self.catalog_api.create_endpoint(
context, endpoint_ref['id'], endpoint_ref)
legacy_endpoint_ref['id'] = legacy_endpoint_id
return {'endpoint': legacy_endpoint_ref}
def delete_endpoint(self, context, endpoint_id):
"""Delete up to three v3 endpoint refs based on a legacy ref ID."""
self.assert_admin(context)
self.catalog_api.delete_endpoint(context, endpoint_id)
deleted_at_least_one = False
for endpoint in self.catalog_api.list_endpoints(context):
if endpoint['legacy_endpoint_id'] == endpoint_id:
self.catalog_api.delete_endpoint(context, endpoint['id'])
deleted_at_least_one = True
if not deleted_at_least_one:
raise exception.EndpointNotFound(endpoint_id=endpoint_id)
class ServiceV3(controller.V3Controller):

View File

@ -0,0 +1,54 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2012 OpenStack LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import migrate
import sqlalchemy as sql
def upgrade(migrate_engine):
"""Create API-version specific endpoint tables."""
meta = sql.MetaData()
meta.bind = migrate_engine
legacy_table = sql.Table('endpoint', meta, autoload=True)
legacy_table.rename('endpoint_v2')
new_table = sql.Table(
'endpoint_v3',
meta,
sql.Column('id', sql.String(64), primary_key=True),
sql.Column('legacy_endpoint_id', sql.String(64)),
sql.Column('interface', sql.String(8), nullable=False),
sql.Column('region', sql.String(255)),
sql.Column('service_id',
sql.String(64),
sql.ForeignKey('service.id'),
nullable=False),
sql.Column('url', sql.Text(), nullable=False),
sql.Column('extra', sql.Text()))
new_table.create(migrate_engine, checkfirst=True)
def downgrade(migrate_engine):
"""Replace API-version specific endpoint tables with one based on v2."""
meta = sql.MetaData()
meta.bind = migrate_engine
new_table = sql.Table('endpoint_v3', meta, autoload=True)
new_table.drop()
legacy_table = sql.Table('endpoint_v2', meta, autoload=True)
legacy_table.rename('endpoint')

View File

@ -0,0 +1,96 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2012 OpenStack LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import json
import uuid
import sqlalchemy as sql
from sqlalchemy import orm
ENDPOINT_TYPES = ['public', 'internal', 'admin']
def upgrade(migrate_engine):
"""Split each legacy endpoint into seperate records for each interface."""
meta = sql.MetaData()
meta.bind = migrate_engine
legacy_table = sql.Table('endpoint_v2', meta, autoload=True)
new_table = sql.Table('endpoint_v3', meta, autoload=True)
session = orm.sessionmaker(bind=migrate_engine)()
for ref in session.query(legacy_table).all():
# pull urls out of extra
extra = json.loads(ref.extra)
urls = dict((i, extra.pop('%surl' % i)) for i in ENDPOINT_TYPES)
for interface in ENDPOINT_TYPES:
endpoint = {
'id': uuid.uuid4().hex,
'legacy_endpoint_id': ref.id,
'interface': interface,
'region': ref.region,
'service_id': ref.service_id,
'url': urls[interface],
'extra': json.dumps(extra),
}
session.execute(
'INSERT INTO `%s` (%s) VALUES (%s)' % (
new_table.name,
', '.join('%s' % k for k in endpoint.keys()),
', '.join("'%s'" % v for v in endpoint.values())))
session.commit()
def downgrade(migrate_engine):
"""Re-create the v2 endpoints table based on v3 endpoints."""
meta = sql.MetaData()
meta.bind = migrate_engine
legacy_table = sql.Table('endpoint_v2', meta, autoload=True)
new_table = sql.Table('endpoint_v3', meta, autoload=True)
session = orm.sessionmaker(bind=migrate_engine)()
for ref in session.query(new_table).all():
extra = json.loads(ref.extra)
extra['%surl' % ref.interface] = ref.url
endpoint = {
'id': ref.legacy_endpoint_id,
'region': ref.region,
'service_id': ref.service_id,
'extra': json.dumps(extra),
}
try:
session.execute(
'INSERT INTO `%s` (%s) VALUES (%s)' % (
legacy_table.name,
', '.join('%s' % k for k in endpoint.keys()),
', '.join("'%s'" % v for v in endpoint.values())))
except sql.exc.IntegrityError:
q = session.query(legacy_table)
q = q.filter_by(id=ref.legacy_endpoint_id)
legacy_ref = q.one()
extra = json.loads(legacy_ref.extra)
extra['%surl' % ref.interface] = ref.url
session.execute(
'UPDATE `%s` SET extra=\'%s\' WHERE id="%s"' % (
legacy_table.name,
json.dumps(extra),
legacy_ref.id))
session.commit()

View File

@ -0,0 +1,51 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2012 OpenStack LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import migrate
import sqlalchemy as sql
def upgrade(migrate_engine):
"""Replace API-version specific endpoint tables with one based on v3."""
meta = sql.MetaData()
meta.bind = migrate_engine
legacy_table = sql.Table('endpoint_v2', meta, autoload=True)
legacy_table.drop()
new_table = sql.Table('endpoint_v3', meta, autoload=True)
new_table.rename('endpoint')
def downgrade(migrate_engine):
"""Create API-version specific endpoint tables."""
meta = sql.MetaData()
meta.bind = migrate_engine
new_table = sql.Table('endpoint', meta, autoload=True)
new_table.rename('endpoint_v3')
legacy_table = sql.Table(
'endpoint_v2',
meta,
sql.Column('id', sql.String(64), primary_key=True),
sql.Column('region', sql.String(255)),
sql.Column('service_id',
sql.String(64),
sql.ForeignKey('service.id'),
nullable=False),
sql.Column('extra', sql.Text()))
legacy_table.create(migrate_engine, checkfirst=True)

View File

@ -885,6 +885,9 @@ class CatalogTests(object):
# create an endpoint attached to the service
endpoint = {
'id': uuid.uuid4().hex,
'region': uuid.uuid4().hex,
'interface': uuid.uuid4().hex,
'url': uuid.uuid4().hex,
'service_id': service['id'],
}
self.catalog_api.create_endpoint(endpoint['id'], endpoint)

View File

@ -217,66 +217,52 @@ class SqlToken(SqlTests, test_backend.TokenTests):
class SqlCatalog(SqlTests, test_backend.CatalogTests):
def test_malformed_catalog_throws_error(self):
self.catalog_api.create_service('a', {"id": "a", "desc": "a1",
"name": "b"})
badurl = "http://192.168.1.104:$(compute_port)s/v2/$(tenant)s"
self.catalog_api.create_endpoint('b', {"id": "b", "region": "b1",
"service_id": "a", "adminurl": badurl,
"internalurl": badurl,
"publicurl": badurl})
with self.assertRaises(exception.MalformedEndpoint):
self.catalog_api.get_catalog('fake-user', 'fake-tenant')
def test_get_catalog_without_endpoint(self):
new_service = {
service = {
'id': uuid.uuid4().hex,
'type': uuid.uuid4().hex,
'name': uuid.uuid4().hex,
'description': uuid.uuid4().hex,
}
self.catalog_api.create_service(
new_service['id'],
new_service.copy())
service_id = new_service['id']
self.catalog_api.create_service(service['id'], service.copy())
new_endpoint = {
malformed_url = "http://192.168.1.104:$(compute_port)s/v2/$(tenant)s"
endpoint = {
'id': uuid.uuid4().hex,
'region': uuid.uuid4().hex,
'service_id': service_id,
'service_id': service['id'],
'interface': 'public',
'url': malformed_url,
}
self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
self.catalog_api.create_endpoint(
new_endpoint['id'],
new_endpoint.copy())
with self.assertRaises(exception.MalformedEndpoint):
self.catalog_api.get_catalog('fake-user', 'fake-tenant')
def test_get_catalog_with_empty_public_url(self):
service = {
'id': uuid.uuid4().hex,
'type': uuid.uuid4().hex,
'name': uuid.uuid4().hex,
'description': uuid.uuid4().hex,
}
self.catalog_api.create_service(service['id'], service.copy())
endpoint = {
'id': uuid.uuid4().hex,
'region': uuid.uuid4().hex,
'interface': 'public',
'url': '',
'service_id': service['id'],
}
self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
catalog = self.catalog_api.get_catalog('user', 'tenant')
service_type = new_service['type']
region = new_endpoint['region']
self.assertEqual(catalog[region][service_type]['name'],
new_service['name'])
self.assertEqual(catalog[region][service_type]['id'],
new_endpoint['id'])
self.assertEqual(catalog[region][service_type]['publicURL'],
"")
self.assertEqual(catalog[region][service_type]['adminURL'],
None)
self.assertEqual(catalog[region][service_type]['internalURL'],
None)
def test_delete_service_with_endpoints(self):
self.catalog_api.create_service('c', {"id": "c", "desc": "a1",
"name": "d"})
self.catalog_api.create_endpoint('d', {"id": "d", "region": None,
"service_id": "c", "adminurl": None,
"internalurl": None,
"publicurl": None})
self.catalog_api.delete_service("c")
self.assertRaises(exception.ServiceNotFound,
self.catalog_man.delete_service, {}, "c")
self.assertRaises(exception.EndpointNotFound,
self.catalog_man.delete_endpoint, {}, "d")
catalog_endpoint = catalog[endpoint['region']][service['type']]
self.assertEqual(catalog_endpoint['name'], service['name'])
self.assertEqual(catalog_endpoint['id'], endpoint['id'])
self.assertEqual(catalog_endpoint['publicURL'], '')
self.assertIsNone(catalog_endpoint.get('adminURL'))
self.assertIsNone(catalog_endpoint.get('internalURL'))
class SqlPolicy(SqlTests, test_backend.PolicyTests):

View File

@ -16,10 +16,10 @@
import copy
import json
import uuid
from migrate.versioning import api as versioning_api
import sqlalchemy
from sqlalchemy.orm import sessionmaker
from keystone.common import sql
from keystone import config
@ -27,6 +27,7 @@ from keystone import test
from keystone.common.sql import migration
import default_fixtures
CONF = config.CONF
@ -38,7 +39,11 @@ class SqlUpgradeTests(test.TestCase):
test.testsdir('backend_sql.conf')])
# create and share a single sqlalchemy engine for testing
self.engine = sql.Base().get_engine(allow_global_engine=False)
base = sql.Base()
self.engine = base.get_engine(allow_global_engine=False)
self.Session = base.get_sessionmaker(
engine=self.engine,
autocommit=False)
self.metadata = sqlalchemy.MetaData()
# populate the engine with tables & fixtures
@ -64,9 +69,7 @@ class SqlUpgradeTests(test.TestCase):
self.assertEqual(expected_cols, actual_cols, '%s table' % table_name)
def test_upgrade_0_to_1(self):
self.assertEqual(self.schema.version, 0, "DB is at version 0")
self._migrate(self.repo_path, 1)
self.assertEqual(self.schema.version, 1, "DB is at version 1")
self.upgrade(1)
self.assertTableColumns("user", ["id", "name", "extra"])
self.assertTableColumns("tenant", ["id", "name", "extra"])
self.assertTableColumns("role", ["id", "name"])
@ -76,23 +79,18 @@ class SqlUpgradeTests(test.TestCase):
self.populate_user_table()
def test_upgrade_5_to_6(self):
self._migrate(self.repo_path, 5)
self.assertEqual(self.schema.version, 5)
self.upgrade(5)
self.assertTableDoesNotExist('policy')
self._migrate(self.repo_path, 6)
self.assertEqual(self.schema.version, 6)
self.upgrade(6)
self.assertTableExists('policy')
self.assertTableColumns('policy', ['id', 'type', 'blob', 'extra'])
def test_upgrade_7_to_9(self):
self.assertEqual(self.schema.version, 0)
self._migrate(self.repo_path, 7)
self.upgrade(7)
self.populate_user_table()
self.populate_tenant_table()
self._migrate(self.repo_path, 9)
self.assertEqual(self.schema.version, 9)
self.upgrade(9)
self.assertTableColumns("user",
["id", "name", "extra", "password",
"enabled"])
@ -103,8 +101,7 @@ class SqlUpgradeTests(test.TestCase):
self.assertTableColumns("user_tenant_membership",
["user_id", "tenant_id"])
self.assertTableColumns("metadata", ["user_id", "tenant_id", "data"])
maker = sessionmaker(bind=self.engine)
session = maker()
session = self.Session()
user_table = sqlalchemy.Table("user",
self.metadata,
autoload=True)
@ -120,25 +117,155 @@ class SqlUpgradeTests(test.TestCase):
session.commit()
def test_downgrade_9_to_7(self):
self.assertEqual(self.schema.version, 0)
self._migrate(self.repo_path, 9)
self._migrate(self.repo_path, 7, False)
self.upgrade(9)
self.downgrade(7)
def test_upgrade_9_to_12(self):
self.upgrade(9)
service_extra = {
'name': uuid.uuid4().hex,
}
service = {
'id': uuid.uuid4().hex,
'type': uuid.uuid4().hex,
'extra': json.dumps(service_extra),
}
endpoint_extra = {
'publicurl': uuid.uuid4().hex,
'internalurl': uuid.uuid4().hex,
'adminurl': uuid.uuid4().hex,
}
endpoint = {
'id': uuid.uuid4().hex,
'region': uuid.uuid4().hex,
'service_id': service['id'],
'extra': json.dumps(endpoint_extra),
}
session = self.Session()
self.insert_dict(session, 'service', service)
self.insert_dict(session, 'endpoint', endpoint)
session.commit()
self.upgrade(12)
self.assertTableColumns(
'service',
['id', 'type', 'extra'])
self.assertTableColumns(
'endpoint',
['id', 'legacy_endpoint_id', 'interface', 'region', 'service_id',
'url', 'extra'])
endpoint_table = sqlalchemy.Table(
'endpoint', self.metadata, autoload=True)
session = self.Session()
self.assertEqual(session.query(endpoint_table).count(), 3)
for interface in ['public', 'internal', 'admin']:
q = session.query(endpoint_table)
q = q.filter_by(legacy_endpoint_id=endpoint['id'])
q = q.filter_by(interface=interface)
ref = q.one()
self.assertNotEqual(ref.id, endpoint['id'])
self.assertEqual(ref.legacy_endpoint_id, endpoint['id'])
self.assertEqual(ref.interface, interface)
self.assertEqual(ref.region, endpoint['region'])
self.assertEqual(ref.service_id, endpoint['service_id'])
self.assertEqual(ref.url, endpoint_extra['%surl' % interface])
self.assertEqual(ref.extra, '{}')
def test_downgrade_12_to_9(self):
self.upgrade(12)
service_extra = {
'name': uuid.uuid4().hex,
}
service = {
'id': uuid.uuid4().hex,
'type': uuid.uuid4().hex,
'extra': json.dumps(service_extra),
}
common_endpoint_attrs = {
'legacy_endpoint_id': uuid.uuid4().hex,
'region': uuid.uuid4().hex,
'service_id': service['id'],
'extra': json.dumps({}),
}
endpoints = {
'public': {
'id': uuid.uuid4().hex,
'interface': 'public',
'url': uuid.uuid4().hex,
},
'internal': {
'id': uuid.uuid4().hex,
'interface': 'internal',
'url': uuid.uuid4().hex,
},
'admin': {
'id': uuid.uuid4().hex,
'interface': 'admin',
'url': uuid.uuid4().hex,
},
}
session = self.Session()
self.insert_dict(session, 'service', service)
for endpoint in endpoints.values():
endpoint.update(common_endpoint_attrs)
self.insert_dict(session, 'endpoint', endpoint)
session.commit()
self.downgrade(8)
self.assertTableColumns(
'service',
['id', 'type', 'extra'])
self.assertTableColumns(
'endpoint',
['id', 'region', 'service_id', 'extra'])
endpoint_table = sqlalchemy.Table(
'endpoint', self.metadata, autoload=True)
session = self.Session()
self.assertEqual(session.query(endpoint_table).count(), 1)
q = session.query(endpoint_table)
q = q.filter_by(id=common_endpoint_attrs['legacy_endpoint_id'])
ref = q.one()
self.assertEqual(ref.id, common_endpoint_attrs['legacy_endpoint_id'])
self.assertEqual(ref.region, endpoint['region'])
self.assertEqual(ref.service_id, endpoint['service_id'])
extra = json.loads(ref.extra)
for interface in ['public', 'internal', 'admin']:
expected_url = endpoints[interface]['url']
self.assertEqual(extra['%surl' % interface], expected_url)
def insert_dict(self, session, table_name, d):
"""Naively inserts key-value pairs into a table, given a dictionary."""
session.execute(
'INSERT INTO `%s` (%s) VALUES (%s)' % (
table_name,
', '.join('%s' % k for k in d.keys()),
', '.join("'%s'" % v for v in d.values())))
def test_downgrade_to_0(self):
self._migrate(self.repo_path, 9)
self._migrate(self.repo_path, 0, False)
self.upgrade(12)
self.downgrade(0)
for table_name in ["user", "token", "role", "user_tenant_membership",
"metadata"]:
self.assertTableDoesNotExist(table_name)
def test_upgrade_6_to_7(self):
self._migrate(self.repo_path, 6)
self.assertEqual(self.schema.version, 6, "DB is at version 6")
self.upgrade(6)
self.assertTableDoesNotExist('credential')
self.assertTableDoesNotExist('domain')
self.assertTableDoesNotExist('user_domain_metadata')
self._migrate(self.repo_path, 7)
self.assertEqual(self.schema.version, 7, "DB is at version 7")
self.upgrade(7)
self.assertTableExists('credential')
self.assertTableColumns('credential', ['id', 'user_id', 'project_id',
'blob', 'type', 'extra'])
@ -191,12 +318,20 @@ class SqlUpgradeTests(test.TestCase):
else:
raise AssertionError('Table "%s" already exists' % table_name)
def _migrate(self, repository, version, upgrade=True):
err = ""
def upgrade(self, *args, **kwargs):
self._migrate(*args, **kwargs)
def downgrade(self, *args, **kwargs):
self._migrate(*args, downgrade=True, **kwargs)
def _migrate(self, version, repository=None, downgrade=False):
repository = repository or self.repo_path
err = ''
version = versioning_api._migrate_version(self.schema,
version,
upgrade,
not downgrade,
err)
changeset = self.schema.changeset(version)
for ver, change in changeset:
self.schema.runchange(ver, change, changeset.step)
self.assertEqual(self.schema.version, version)

View File

@ -44,6 +44,7 @@ class RestfulTestCase(test_content_types.RestfulTestCase):
ref = self.new_ref()
ref['interface'] = uuid.uuid4().hex
ref['service_id'] = service_id
ref['url'] = uuid.uuid4().hex
return ref
def new_domain_ref(self):