diff --git a/nova/db/sqlalchemy/migrate_repo/versions/156_cidr_column_length.py b/nova/db/sqlalchemy/migrate_repo/versions/156_cidr_column_length.py new file mode 100644 index 000000000000..fda0c50750a3 --- /dev/null +++ b/nova/db/sqlalchemy/migrate_repo/versions/156_cidr_column_length.py @@ -0,0 +1,56 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 IBM Corp. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from sqlalchemy import MetaData, String, Table +from sqlalchemy.dialects import postgresql + + +CIDR_TABLE_COLUMNS = [ + # table name, column name + ('security_group_rules', 'cidr'), + ('provider_fw_rules', 'cidr'), + ('networks', 'cidr'), + ('networks', 'cidr_v6')] + + +def upgrade(migrate_engine): + """Convert String columns holding IP addresses to INET for postgresql.""" + meta = MetaData() + meta.bind = migrate_engine + dialect = migrate_engine.url.get_dialect() + + if dialect is postgresql.dialect: + for table, column in CIDR_TABLE_COLUMNS: + # can't use migrate's alter() because it does not support + # explicit casting + migrate_engine.execute( + "ALTER TABLE %(table)s " + "ALTER COLUMN %(column)s TYPE INET USING %(column)s::INET" + % locals()) + else: + for table, column in CIDR_TABLE_COLUMNS: + t = Table(table, meta, autoload=True) + getattr(t.c, column).alter(type=String(43)) + + +def downgrade(migrate_engine): + """Convert columns back to the larger String(255).""" + meta = MetaData() + meta.bind = migrate_engine + for table, column in CIDR_TABLE_COLUMNS: + t = Table(table, meta, autoload=True) + getattr(t.c, column).alter(type=String(39)) diff --git a/nova/db/sqlalchemy/models.py b/nova/db/sqlalchemy/models.py index 5eeae30dc7b4..28d8f0882514 100644 --- a/nova/db/sqlalchemy/models.py +++ b/nova/db/sqlalchemy/models.py @@ -529,7 +529,7 @@ class SecurityGroupIngressRule(BASE, NovaBase): protocol = Column(String(5)) # "tcp", "udp", or "icmp" from_port = Column(Integer) to_port = Column(Integer) - cidr = Column(types.IPAddress()) + cidr = Column(types.CIDR()) # Note: This is not the parent SecurityGroup. It's SecurityGroup we're # granting access for. @@ -549,7 +549,7 @@ class ProviderFirewallRule(BASE, NovaBase): protocol = Column(String(5)) # "tcp", "udp", or "icmp" from_port = Column(Integer) to_port = Column(Integer) - cidr = Column(types.IPAddress()) + cidr = Column(types.CIDR()) class KeyPair(BASE, NovaBase): @@ -599,8 +599,8 @@ class Network(BASE, NovaBase): label = Column(String(255)) injected = Column(Boolean, default=False) - cidr = Column(types.IPAddress(), unique=True) - cidr_v6 = Column(types.IPAddress(), unique=True) + cidr = Column(types.CIDR(), unique=True) + cidr_v6 = Column(types.CIDR(), unique=True) multi_host = Column(Boolean, default=False) gateway_v6 = Column(types.IPAddress()) diff --git a/nova/db/sqlalchemy/types.py b/nova/db/sqlalchemy/types.py index ef861b83218d..5a04a025317e 100644 --- a/nova/db/sqlalchemy/types.py +++ b/nova/db/sqlalchemy/types.py @@ -36,3 +36,15 @@ class IPAddress(types.TypeDecorator): elif utils.is_valid_ipv6(value): return utils.get_shortened_ipv6(value) return value + + +class CIDR(types.TypeDecorator): + """An SQLAlchemy type representing a CIDR definition.""" + impl = types.String(43).with_variant(postgresql.INET(), 'postgresql') + + def process_bind_param(self, value, dialect): + """Process/Formats the value before insert it into the db.""" + # NOTE(sdague): normalize all the inserts + if utils.is_valid_ipv6_cidr(value): + return utils.get_shortened_ipv6_cidr(value) + return value diff --git a/nova/tests/test_migrations.py b/nova/tests/test_migrations.py index e71b97513a59..fb7411cd0d86 100644 --- a/nova/tests/test_migrations.py +++ b/nova/tests/test_migrations.py @@ -507,28 +507,52 @@ class TestMigrations(BaseMigrationTestCase): # migration 149, changes IPAddr storage format def _prerun_149(self, engine): provider_fw_rules = get_table(engine, 'provider_fw_rules') - data = [ - {'protocol': 'tcp', 'from_port': 1234, - 'to_port': 1234, 'cidr': "127.0.0.1"}, - {'protocol': 'tcp', 'from_port': 1234, - 'to_port': 1234, 'cidr': "255.255.255.255"}, - {'protocol': 'tcp', 'from_port': 1234, - 'to_port': 1234, 'cidr': "2001:db8::1:2"}, - {'protocol': 'tcp', 'from_port': 1234, - 'to_port': 1234, 'cidr': "::1"} - ] - engine.execute(provider_fw_rules.insert(), data) + console_pools = get_table(engine, 'console_pools') + data = { + 'provider_fw_rules': + [ + {'protocol': 'tcp', 'from_port': 1234, + 'to_port': 1234, 'cidr': "127.0.0.1/30"}, + {'protocol': 'tcp', 'from_port': 1234, + 'to_port': 1234, 'cidr': "128.128.128.128/16"}, + {'protocol': 'tcp', 'from_port': 1234, + 'to_port': 1234, 'cidr': "2001:db8::1:2/48"}, + {'protocol': 'tcp', 'from_port': 1234, + 'to_port': 1234, 'cidr': "::1/64"} + ], + 'console_pools': + [ + {'address': '10.10.10.10'}, + {'address': '128.100.100.100'}, + {'address': '2002:2002:2002:2002:2002:2002:2002:2002'}, + {'address': '::1'}, + ] + } + + engine.execute(provider_fw_rules.insert(), data['provider_fw_rules']) + + for pool in data['console_pools']: + engine.execute(console_pools.insert(), pool) + return data def _check_149(self, engine, data): provider_fw_rules = get_table(engine, 'provider_fw_rules') result = provider_fw_rules.select().execute() - iplist = map(lambda x: x['cidr'], data) + iplist = map(lambda x: x['cidr'], data['provider_fw_rules']) for row in result: self.assertIn(row['cidr'], iplist) + console_pools = get_table(engine, 'console_pools') + result = console_pools.select().execute() + + iplist = map(lambda x: x['address'], data['console_pools']) + + for row in result: + self.assertIn(row['address'], iplist) + # migration 151 - changes period_beginning and period_ending to DateTime def _prerun_151(self, engine): task_log = get_table(engine, 'task_log') @@ -703,3 +727,32 @@ class TestMigrations(BaseMigrationTestCase): # override __eq__, but if we stringify them then they do. self.assertEqual(str(base_column.type), str(shadow_column.type)) + + # migration 156 - introduce CIDR type + def _prerun_156(self, engine): + # assume the same data as from 149 + data = { + 'provider_fw_rules': + [ + {'protocol': 'tcp', 'from_port': 1234, + 'to_port': 1234, 'cidr': "127.0.0.1/30"}, + {'protocol': 'tcp', 'from_port': 1234, + 'to_port': 1234, 'cidr': "128.128.128.128/16"}, + {'protocol': 'tcp', 'from_port': 1234, + 'to_port': 1234, 'cidr': "2001:db8::1:2/48"}, + {'protocol': 'tcp', 'from_port': 1234, + 'to_port': 1234, 'cidr': "::1/64"} + ], + 'console_pools': + [ + {'address': '10.10.10.10'}, + {'address': '128.100.100.100'}, + {'address': '2002:2002:2002:2002:2002:2002:2002:2002'}, + {'address': '::1'}, + ] + } + return data + + def _check_156(self, engine, data): + # recheck the 149 data + self._check_149(engine, data) diff --git a/nova/tests/test_utils.py b/nova/tests/test_utils.py index 0aa2a310c074..c4e3ea49fbe4 100644 --- a/nova/tests/test_utils.py +++ b/nova/tests/test_utils.py @@ -491,6 +491,17 @@ class GenericUtilsTestCase(test.TestCase): self.assertFalse(utils.is_valid_ipv6("foo")) self.assertFalse(utils.is_valid_ipv6("127.0.0.1")) + def test_is_valid_ipv6_cidr(self): + self.assertTrue(utils.is_valid_ipv6_cidr("2600::/64")) + self.assertTrue(utils.is_valid_ipv6_cidr( + "abcd:ef01:2345:6789:abcd:ef01:192.168.254.254/48")) + self.assertTrue(utils.is_valid_ipv6_cidr( + "0000:0000:0000:0000:0000:0000:0000:0001/32")) + self.assertTrue(utils.is_valid_ipv6_cidr( + "0000:0000:0000:0000:0000:0000:0000:0001")) + self.assertFalse(utils.is_valid_ipv6_cidr("foo")) + self.assertFalse(utils.is_valid_ipv6_cidr("127.0.0.1")) + def test_get_shortened_ipv6(self): self.assertEquals("abcd:ef01:2345:6789:abcd:ef01:c0a8:fefe", utils.get_shortened_ipv6( @@ -505,6 +516,18 @@ class GenericUtilsTestCase(test.TestCase): self.assertRaises(netaddr.AddrFormatError, utils.get_shortened_ipv6, "failure") + def test_get_shortened_ipv6_cidr(self): + self.assertEquals("2600::/64", utils.get_shortened_ipv6_cidr( + "2600:0000:0000:0000:0000:0000:0000:0000/64")) + self.assertEquals("2600::/64", utils.get_shortened_ipv6_cidr( + "2600::1/64")) + self.assertRaises(netaddr.AddrFormatError, + utils.get_shortened_ipv6_cidr, + "127.0.0.1") + self.assertRaises(netaddr.AddrFormatError, + utils.get_shortened_ipv6_cidr, + "failure") + class MonkeyPatchTestCase(test.TestCase): """Unit test for utils.monkey_patch().""" diff --git a/nova/utils.py b/nova/utils.py index aaf25814289d..126e3c7e47fb 100644 --- a/nova/utils.py +++ b/nova/utils.py @@ -896,11 +896,24 @@ def is_valid_ipv6(address): return netaddr.valid_ipv6(address) +def is_valid_ipv6_cidr(address): + try: + str(netaddr.IPNetwork(address, version=6).cidr) + return True + except Exception: + return False + + def get_shortened_ipv6(address): addr = netaddr.IPAddress(address, version=6) return str(addr.ipv6()) +def get_shortened_ipv6_cidr(address): + net = netaddr.IPNetwork(address, version=6) + return str(net.cidr) + + def is_valid_cidr(address): """Check if the provided ipv4 or ipv6 address is a valid CIDR address or not"""