diff --git a/neutron_lib/db/sqlalchemytypes.py b/neutron_lib/db/sqlalchemytypes.py new file mode 100644 index 000000000..5490a142b --- /dev/null +++ b/neutron_lib/db/sqlalchemytypes.py @@ -0,0 +1,83 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Custom SQLAlchemy types.""" + +import netaddr +from sqlalchemy import types + +from neutron_lib._i18n import _ + + +class IPAddress(types.TypeDecorator): + + impl = types.String(64) + + def process_result_value(self, value, dialect): + return netaddr.IPAddress(value) + + def process_bind_param(self, value, dialect): + if not isinstance(value, netaddr.IPAddress): + raise AttributeError(_("Received type '%(type)s' and value " + "'%(value)s'. Expecting netaddr.IPAddress " + "type.") % {'type': type(value), + 'value': value}) + + return str(value) + + +class CIDR(types.TypeDecorator): + + impl = types.String(64) + + def process_result_value(self, value, dialect): + return netaddr.IPNetwork(value) + + def process_bind_param(self, value, dialect): + if not isinstance(value, netaddr.IPNetwork): + raise AttributeError(_("Received type '%(type)s' and value " + "'%(value)s'. Expecting netaddr.IPNetwork " + "type.") % {'type': type(value), + 'value': value}) + return str(value) + + +class MACAddress(types.TypeDecorator): + + impl = types.String(64) + + def process_result_value(self, value, dialect): + return netaddr.EUI(value) + + def process_bind_param(self, value, dialect): + if not isinstance(value, netaddr.EUI): + raise AttributeError(_("Received type '%(type)s' and value " + "'%(value)s'. Expecting netaddr.EUI " + "type.") % {'type': type(value), + 'value': value}) + return str(value) + + +class TruncatedDateTime(types.TypeDecorator): + """Truncates microseconds. + + Use this for datetime fields so we don't have to worry about DB-specific + behavior when it comes to rounding/truncating microseconds off of + timestamps. + """ + + impl = types.DateTime + + def process_bind_param(self, value, dialect): + return value.replace(microsecond=0) if value else value + + process_result_value = process_bind_param diff --git a/neutron_lib/tests/_base.py b/neutron_lib/tests/_base.py index 7f4cb66cc..148e540f6 100644 --- a/neutron_lib/tests/_base.py +++ b/neutron_lib/tests/_base.py @@ -32,7 +32,7 @@ from neutron_lib import exceptions from neutron_lib import fixture from neutron_lib.tests import _post_mortem_debug as post_mortem_debug -from neutron_lib.tests import _tools as tools +from neutron_lib.tests import tools CONF = cfg.CONF diff --git a/neutron_lib/tests/_tools.py b/neutron_lib/tests/tools.py similarity index 85% rename from neutron_lib/tests/_tools.py rename to neutron_lib/tests/tools.py index 0992c2c61..770ad251d 100644 --- a/neutron_lib/tests/_tools.py +++ b/neutron_lib/tests/tools.py @@ -14,6 +14,7 @@ # under the License. import platform +import random import warnings import fixtures @@ -57,3 +58,11 @@ def is_bsd(): if 'bsd' in system.lower(): return True return False + + +def get_random_cidr(version=4): + if version == 4: + return '10.%d.%d.0/%d' % (random.randint(3, 254), + random.randint(3, 254), + 24) + return '2001:db8:%x::/%d' % (random.getrandbits(16), 64) diff --git a/neutron_lib/tests/unit/api/test_conversions.py b/neutron_lib/tests/unit/api/test_conversions.py index b98dee4cd..78794ed02 100644 --- a/neutron_lib/tests/unit/api/test_conversions.py +++ b/neutron_lib/tests/unit/api/test_conversions.py @@ -22,7 +22,7 @@ from neutron_lib.api import converters from neutron_lib import constants from neutron_lib import exceptions as n_exc from neutron_lib.tests import _base as base -from neutron_lib.tests import _tools as tools +from neutron_lib.tests import tools class TestConvertToBoolean(base.BaseTestCase): diff --git a/neutron_lib/tests/unit/db/test_sqlalchemytypes.py b/neutron_lib/tests/unit/db/test_sqlalchemytypes.py new file mode 100644 index 000000000..6542b9f19 --- /dev/null +++ b/neutron_lib/tests/unit/db/test_sqlalchemytypes.py @@ -0,0 +1,260 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import abc + +import netaddr +from oslo_db import exception +from oslo_db.sqlalchemy import enginefacade +from oslo_db.sqlalchemy import test_fixtures +from oslo_utils import timeutils +from oslo_utils import uuidutils +import six +import sqlalchemy as sa + +from neutron_lib import context +from neutron_lib.db import sqlalchemytypes +from neutron_lib.tests import _base as test_base +from neutron_lib.tests import tools +from neutron_lib.utils import net + + +@six.add_metaclass(abc.ABCMeta) +class SqlAlchemyTypesBaseTestCase(test_fixtures.OpportunisticDBTestMixin, + test_base.BaseTestCase): + def setUp(self): + super(SqlAlchemyTypesBaseTestCase, self).setUp() + self.engine = enginefacade.writer.get_engine() + meta = sa.MetaData(bind=self.engine) + self.test_table = self._get_test_table(meta) + self.test_table.create() + self.addCleanup(meta.drop_all) + self.ctxt = context.get_admin_context() + + @abc.abstractmethod + def _get_test_table(self, meta): + """Returns a new sa.Table() object for this test case.""" + + def _add_row(self, **kargs): + self.engine.execute(self.test_table.insert().values(**kargs)) + + def _get_all(self): + rows_select = self.test_table.select() + return self.engine.execute(rows_select).fetchall() + + def _update_row(self, **kargs): + self.engine.execute(self.test_table.update().values(**kargs)) + + def _delete_rows(self): + self.engine.execute(self.test_table.delete()) + + def _validate_crud(self, data_field_name, expected=None): + objs = self._get_all() + self.assertEqual(len(expected) if expected else 0, len(objs)) + if expected: + for obj in objs: + name = obj['id'] + self.assertEqual(expected[name], obj[data_field_name]) + + +class IPAddressTestCase(SqlAlchemyTypesBaseTestCase): + + def _get_test_table(self, meta): + return sa.Table( + 'fakeipaddressmodels', + meta, + sa.Column('id', sa.String(36), primary_key=True, nullable=False), + sa.Column('ip', sqlalchemytypes.IPAddress)) + + def _validate_ip_address(self, data_field_name, expected=None): + objs = self._get_all() + self.assertEqual(len(expected) if expected else 0, len(objs)) + if expected: + for obj in objs: + name = obj['id'] + self.assertEqual(expected[name], obj[data_field_name]) + + def _test_crud(self, ip_addresses): + ip = netaddr.IPAddress(ip_addresses[0]) + self._add_row(id='fake_id', ip=ip) + self._validate_ip_address(data_field_name='ip', + expected={'fake_id': ip}) + + ip2 = netaddr.IPAddress(ip_addresses[1]) + self._update_row(ip=ip2) + self._validate_ip_address(data_field_name='ip', + expected={'fake_id': ip2}) + + self._delete_rows() + self._validate_ip_address(data_field_name='ip', expected=None) + + def test_crud(self): + ip_addresses = ["10.0.0.1", "10.0.0.2"] + self._test_crud(ip_addresses) + + ip_addresses = ["2210::ffff:ffff:ffff:ffff", + "2120::ffff:ffff:ffff:ffff"] + self._test_crud(ip_addresses) + + def test_wrong_type(self): + self.assertRaises(exception.DBError, self._add_row, + id='fake_id', ip="") + self.assertRaises(exception.DBError, self._add_row, + id='fake_id', ip="10.0.0.5") + + def _test_multiple_create(self, entries): + reference = {} + for entry in entries: + ip = netaddr.IPAddress(entry['ip']) + name = entry['name'] + self._add_row(id=name, ip=ip) + reference[name] = ip + + self._validate_ip_address(data_field_name='ip', expected=reference) + self._delete_rows() + self._validate_ip_address(data_field_name='ip', expected=None) + + def test_multiple_create(self): + ip_addresses = [ + {'name': 'fake_id1', 'ip': "10.0.0.5"}, + {'name': 'fake_id2', 'ip': "10.0.0.1"}, + {'name': 'fake_id3', + 'ip': "2210::ffff:ffff:ffff:ffff"}, + {'name': 'fake_id4', + 'ip': "2120::ffff:ffff:ffff:ffff"}] + self._test_multiple_create(ip_addresses) + + +class CIDRTestCase(SqlAlchemyTypesBaseTestCase): + + def _get_test_table(self, meta): + return sa.Table( + 'fakecidrmodels', + meta, + sa.Column('id', sa.String(36), primary_key=True, nullable=False), + sa.Column('cidr', sqlalchemytypes.CIDR) + ) + + def _get_one(self, value): + row_select = self.test_table.select().\ + where(self.test_table.c.cidr == value) + return self.engine.execute(row_select).first() + + def _update_row(self, key, cidr): + self.engine.execute( + self.test_table.update().values(cidr=cidr). + where(self.test_table.c.cidr == key)) + + def test_crud(self): + cidrs = ["10.0.0.0/24", "10.123.250.9/32", "2001:db8::/42", + "fe80::21e:67ff:fed0:56f0/64"] + + for cidr_str in cidrs: + cidr = netaddr.IPNetwork(cidr_str) + self._add_row(id=uuidutils.generate_uuid(), cidr=cidr) + obj = self._get_one(cidr) + self.assertEqual(cidr, obj['cidr']) + random_cidr = netaddr.IPNetwork(tools.get_random_cidr()) + self._update_row(cidr, random_cidr) + obj = self._get_one(random_cidr) + self.assertEqual(random_cidr, obj['cidr']) + + objs = self._get_all() + self.assertEqual(len(cidrs), len(objs)) + self._delete_rows() + objs = self._get_all() + self.assertEqual(0, len(objs)) + + def test_wrong_cidr(self): + wrong_cidrs = ["10.500.5.0/24", "10.0.0.1/40", "10.0.0.10.0/24", + "cidr", "", '2001:db8:5000::/64', '2001:db8::/130'] + for cidr in wrong_cidrs: + self.assertRaises(exception.DBError, self._add_row, + id=uuidutils.generate_uuid(), cidr=cidr) + + +class MACAddressTestCase(SqlAlchemyTypesBaseTestCase): + + def _get_test_table(self, meta): + return sa.Table( + 'fakemacaddressmodels', + meta, + sa.Column('id', sa.String(36), primary_key=True, nullable=False), + sa.Column('mac', sqlalchemytypes.MACAddress) + ) + + def _get_one(self, value): + row_select = self.test_table.select().\ + where(self.test_table.c.mac == value) + return self.engine.execute(row_select).first() + + def _get_all(self): + rows_select = self.test_table.select() + return self.engine.execute(rows_select).fetchall() + + def _update_row(self, key, mac): + self.engine.execute( + self.test_table.update().values(mac=mac). + where(self.test_table.c.mac == key)) + + def _delete_row(self): + self.engine.execute( + self.test_table.delete()) + + def test_crud(self): + mac_addresses = ['FA:16:3E:00:00:01', 'FA:16:3E:00:00:02'] + + for mac in mac_addresses: + mac = netaddr.EUI(mac) + self._add_row(id=uuidutils.generate_uuid(), mac=mac) + obj = self._get_one(mac) + self.assertEqual(mac, obj['mac']) + random_mac = netaddr.EUI(net.get_random_mac( + ['fe', '16', '3e', '00', '00', '00'])) + self._update_row(mac, random_mac) + obj = self._get_one(random_mac) + self.assertEqual(random_mac, obj['mac']) + + objs = self._get_all() + self.assertEqual(len(mac_addresses), len(objs)) + self._delete_rows() + objs = self._get_all() + self.assertEqual(0, len(objs)) + + def test_wrong_mac(self): + wrong_macs = ["fake", "", -1, + "FK:16:3E:00:00:02", + "FA:16:3E:00:00:020"] + for mac in wrong_macs: + self.assertRaises(exception.DBError, self._add_row, + id=uuidutils.generate_uuid(), mac=mac) + + +class TruncatedDateTimeTestCase(SqlAlchemyTypesBaseTestCase): + + def _get_test_table(self, meta): + return sa.Table( + 'timetable', + meta, + sa.Column('id', sa.String(36), primary_key=True, nullable=False), + sa.Column('thetime', sqlalchemytypes.TruncatedDateTime) + ) + + def test_microseconds_truncated(self): + tstamp = timeutils.utcnow() + tstamp_low = tstamp.replace(microsecond=111111) + tstamp_high = tstamp.replace(microsecond=999999) + self._add_row(id=1, thetime=tstamp_low) + self._add_row(id=2, thetime=tstamp_high) + rows = self._get_all() + self.assertEqual(2, len(rows)) + self.assertEqual(rows[0].thetime, rows[1].thetime) diff --git a/releasenotes/notes/rehome-sqlalchemytypes-14817eb6694463db.yaml b/releasenotes/notes/rehome-sqlalchemytypes-14817eb6694463db.yaml new file mode 100644 index 000000000..f7d8870e7 --- /dev/null +++ b/releasenotes/notes/rehome-sqlalchemytypes-14817eb6694463db.yaml @@ -0,0 +1,5 @@ +--- +features: + - The ``neutron_lib.tests._tools`` module is now public and named + ``tools``. + - The ``sqlalchemytypes`` module is now available in ``neutron_lib.db``.