Default security group table

This change prevents the race condition by enforcing a single default
security group via new table default_security_group. It has tenant_id
as primary key and security_group_id, which is id of default
security group. Migration that inroduces this table has sanity check that
verifies that there is no duplicate default security group in any
tenant.

This idea has come up from discussion in comments to
https://review.openstack.org/135006

DocImpact

Closes-bug: #1194579

Change-Id: Ifa8fbddd22bce4c50836cf443ebe10dff37443ef
This commit is contained in:
Ann Kamyshnikova 2014-12-12 15:30:06 +03:00
parent 910470de36
commit 79c97120de
6 changed files with 176 additions and 10 deletions

View File

@ -0,0 +1,97 @@
# Copyright 2015 OpenStack Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
""" Add default security group table
Revision ID: 14be42f3d0a5
Revises: 41662e32bce2
Create Date: 2014-12-12 14:54:11.123635
"""
# revision identifiers, used by Alembic.
revision = '14be42f3d0a5'
down_revision = '26b54cf9024d'
from alembic import op
import sqlalchemy as sa
from neutron.common import exceptions
# Models can change in time, but migration should rely only on exact
# model state at the current moment, so a separate model is created
# here.
security_group = sa.Table('securitygroups', sa.MetaData(),
sa.Column('id', sa.String(length=36),
nullable=False),
sa.Column('name', sa.String(255)),
sa.Column('tenant_id', sa.String(255)))
class DuplicateSecurityGroupsNamedDefault(exceptions.Conflict):
message = _("Some tenants have more than one security group named "
"'default': %(duplicates)s. All duplicate 'default' security "
"groups must be resolved before upgrading the database.")
def upgrade():
table = op.create_table(
'default_security_group',
sa.Column('tenant_id', sa.String(length=255), nullable=False),
sa.Column('security_group_id', sa.String(length=36), nullable=False),
sa.PrimaryKeyConstraint('tenant_id'),
sa.ForeignKeyConstraint(['security_group_id'],
['securitygroups.id'],
ondelete="CASCADE"))
sel = (sa.select([security_group.c.tenant_id,
security_group.c.id])
.where(security_group.c.name == 'default'))
ins = table.insert(inline=True).from_select(['tenant_id',
'security_group_id'], sel)
op.execute(ins)
def downgrade():
op.drop_table('default_security_group')
def check_sanity(connection):
res = get_duplicate_default_security_groups(connection)
if res:
raise DuplicateSecurityGroupsNamedDefault(
duplicates='; '.join('tenant %s: %s' %
(tenant_id, ', '.join(groups))
for tenant_id, groups in res.iteritems()))
def get_duplicate_default_security_groups(connection):
insp = sa.engine.reflection.Inspector.from_engine(connection)
if 'securitygroups' not in insp.get_table_names():
return {}
session = sa.orm.Session(bind=connection.connect())
subq = (session.query(security_group.c.tenant_id)
.filter(security_group.c.name == 'default')
.group_by(security_group.c.tenant_id)
.having(sa.func.count() > 1)
.subquery())
sg = (session.query(security_group)
.join(subq, security_group.c.tenant_id == subq.c.tenant_id)
.filter(security_group.c.name == 'default')
.all())
res = {}
for s in sg:
res.setdefault(s.tenant_id, []).append(s.id)
return res

View File

@ -1 +1 @@
26b54cf9024d 14be42f3d0a5

View File

@ -17,6 +17,7 @@ import six
from alembic import command as alembic_command from alembic import command as alembic_command
from alembic import config as alembic_config from alembic import config as alembic_config
from alembic import environment
from alembic import script as alembic_script from alembic import script as alembic_script
from alembic import util as alembic_util from alembic import util as alembic_util
from oslo.config import cfg from oslo.config import cfg
@ -89,7 +90,8 @@ def do_upgrade_downgrade(config, cmd):
revision = sign + str(CONF.command.delta) revision = sign + str(CONF.command.delta)
else: else:
revision = CONF.command.revision revision = CONF.command.revision
if CONF.command.name == 'upgrade' and not CONF.command.sql:
run_sanity_checks(config, revision)
do_alembic_command(config, cmd, revision, sql=CONF.command.sql) do_alembic_command(config, cmd, revision, sql=CONF.command.sql)
@ -191,6 +193,22 @@ def get_alembic_config():
return config return config
def run_sanity_checks(config, revision):
script_dir = alembic_script.ScriptDirectory.from_config(config)
def check_sanity(rev, context):
for script in script_dir.iterate_revisions(revision, rev):
if hasattr(script.module, 'check_sanity'):
script.module.check_sanity(context.connection)
return []
with environment.EnvironmentContext(config, script_dir,
fn=check_sanity,
starting_rev=None,
destination_rev=revision):
script_dir.run_env()
def main(): def main():
CONF(project='neutron') CONF(project='neutron')
config = get_alembic_config() config = get_alembic_config()

View File

@ -40,6 +40,21 @@ class SecurityGroup(model_base.BASEV2, models_v2.HasId, models_v2.HasTenant):
description = sa.Column(sa.String(255)) description = sa.Column(sa.String(255))
class DefaultSecurityGroup(model_base.BASEV2):
__tablename__ = 'default_security_group'
tenant_id = sa.Column(sa.String(255), primary_key=True, nullable=False)
security_group_id = sa.Column(sa.String(36),
sa.ForeignKey("securitygroups.id",
ondelete="CASCADE"),
nullable=False)
security_group = orm.relationship(
SecurityGroup, lazy='joined',
backref=orm.backref('default_security_group', cascade='all,delete'),
primaryjoin="SecurityGroup.id==DefaultSecurityGroup.security_group_id",
)
class SecurityGroupPortBinding(model_base.BASEV2): class SecurityGroupPortBinding(model_base.BASEV2):
"""Represents binding between neutron ports and security profiles.""" """Represents binding between neutron ports and security profiles."""
@ -118,8 +133,12 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
tenant_id=tenant_id, tenant_id=tenant_id,
name=s['name']) name=s['name'])
context.session.add(security_group_db) context.session.add(security_group_db)
if default_sg:
context.session.add(DefaultSecurityGroup(
security_group=security_group_db,
tenant_id=security_group_db['tenant_id']))
for ethertype in ext_sg.sg_supported_ethertypes: for ethertype in ext_sg.sg_supported_ethertypes:
if s.get('name') == 'default': if default_sg:
# Allow intercommunication # Allow intercommunication
ingress_rule = SecurityGroupRule( ingress_rule = SecurityGroupRule(
id=uuidutils.generate_uuid(), tenant_id=tenant_id, id=uuidutils.generate_uuid(), tenant_id=tenant_id,
@ -503,19 +522,21 @@ class SecurityGroupDbMixin(ext_sg.SecurityGroupPluginBase):
:returns: the default security group id. :returns: the default security group id.
""" """
filters = {'name': ['default'], 'tenant_id': [tenant_id]} query = self._model_query(context, DefaultSecurityGroup)
default_group = self.get_security_groups(context, filters, try:
default_sg=True) default_group = query.filter(
if not default_group: DefaultSecurityGroup.tenant_id == tenant_id).one()
except exc.NoResultFound:
security_group = { security_group = {
'security_group': {'name': 'default', 'security_group': {'name': 'default',
'tenant_id': tenant_id, 'tenant_id': tenant_id,
'description': _('Default security group')} 'description': _('Default security group')}
} }
ret = self.create_security_group(context, security_group, True) ret = self.create_security_group(context, security_group,
default_sg=True)
return ret['id'] return ret['id']
else: else:
return default_group[0]['id'] return default_group['security_group_id']
def _get_security_groups_on_port(self, context, port): def _get_security_groups_on_port(self, context, port):
"""Check that all security groups on port belong to tenant. """Check that all security groups on port belong to tenant.

View File

@ -18,6 +18,7 @@ import pprint
import alembic import alembic
import alembic.autogenerate import alembic.autogenerate
import alembic.migration import alembic.migration
from alembic import script as alembic_script
import mock import mock
from oslo.config import cfg from oslo.config import cfg
from oslo.config import fixture as config_fixture from oslo.config import fixture as config_fixture
@ -231,3 +232,31 @@ class TestModelsMigrationsMysql(_TestModelsMigrations,
class TestModelsMigrationsPsql(_TestModelsMigrations, class TestModelsMigrationsPsql(_TestModelsMigrations,
test_base.PostgreSQLOpportunisticTestCase): test_base.PostgreSQLOpportunisticTestCase):
pass pass
class TestSanityCheck(test_base.DbTestCase):
def setUp(self):
super(TestSanityCheck, self).setUp()
self.alembic_config = migration.get_alembic_config()
self.alembic_config.neutron_config = cfg.CONF
def test_check_sanity_14be42f3d0a5(self):
SecurityGroup = sqlalchemy.Table(
'securitygroups', sqlalchemy.MetaData(),
sqlalchemy.Column('id', sqlalchemy.String(length=36),
nullable=False),
sqlalchemy.Column('name', sqlalchemy.String(255)),
sqlalchemy.Column('tenant_id', sqlalchemy.String(255)))
with self.engine.connect() as conn:
SecurityGroup.create(conn)
conn.execute(SecurityGroup.insert(), [
{'id': '123d4s', 'tenant_id': 'sssda1', 'name': 'default'},
{'id': '123d4', 'tenant_id': 'sssda1', 'name': 'default'}
])
script_dir = alembic_script.ScriptDirectory.from_config(
self.alembic_config)
script = script_dir.get_revision("14be42f3d0a5").module
self.assertRaises(script.DuplicateSecurityGroupsNamedDefault,
script.check_sanity, conn)

View File

@ -76,7 +76,8 @@ class TestCli(base.BaseTestCase):
self.mock_alembic_err.side_effect = SystemExit self.mock_alembic_err.side_effect = SystemExit
def _main_test_helper(self, argv, func_name, exp_args=(), exp_kwargs={}): def _main_test_helper(self, argv, func_name, exp_args=(), exp_kwargs={}):
with mock.patch.object(sys, 'argv', argv): with mock.patch.object(sys, 'argv', argv), mock.patch.object(
cli, 'run_sanity_checks'):
cli.main() cli.main()
self.do_alembic_cmd.assert_has_calls( self.do_alembic_cmd.assert_has_calls(
[mock.call(mock.ANY, func_name, *exp_args, **exp_kwargs)] [mock.call(mock.ANY, func_name, *exp_args, **exp_kwargs)]