diff --git a/.travis.yml b/.travis.yml index 84b4317..9c49f8b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,6 @@ language: python python: - "2.7" - - "2.6" before_install: - "export NO_EVENTLET=1" install: diff --git a/quark/db/api.py b/quark/db/api.py index eb6ada7..c3e4a63 100644 --- a/quark/db/api.py +++ b/quark/db/api.py @@ -25,9 +25,11 @@ from oslo_log import log as logging from sqlalchemy import event from sqlalchemy import func as sql_func from sqlalchemy import and_, asc, desc, orm, or_, not_ +from sqlalchemy.orm import class_mapper from quark.db import models from quark import network_strategy +from quark import protocols STRATEGY = network_strategy.STRATEGY @@ -68,65 +70,36 @@ def _listify(filters): filters[key] = listified +def _model_attrs(model): + model_map = class_mapper(model) + model_attrs = [x.key for x in model_map.column_attrs] + if "_cidr" in model_attrs: + model_attrs.append("cidr") + if "_deallocated" in model_attrs: + model_attrs.append("deallocated") + return model_attrs + + def _model_query(context, model, filters, fields=None): filters = filters or {} model_filters = [] + eq_filters = ["address", "cidr", "deallocated", "ip_version", + "mac_address_range_id"] + in_filters = ["device_id", "device_owner", "group_id", "id", "mac_address", + "name", "network_id", "segment_id", "subnet_id", + "used_by_tenant_id", "version"] - if filters.get("name"): - model_filters.append(model.name.in_(filters["name"])) - - if filters.get("network_id"): - model_filters.append(model.network_id.in_(filters["network_id"])) - - if filters.get("mac_address"): - model_filters.append(model.mac_address.in_(filters["mac_address"])) - - if filters.get("segment_id"): - model_filters.append(model.segment_id.in_(filters["segment_id"])) - - if filters.get("id"): - model_filters.append(model.id.in_(filters["id"])) - - if filters.get("group_id"): - model_filters.append(model.group_id.in_(filters["group_id"])) - - if filters.get("reuse_after"): - reuse_after = filters["reuse_after"] - reuse = (timeutils.utcnow() - - datetime.timedelta(seconds=reuse_after)) - model_filters.append(model.deallocated_at <= reuse) - - if filters.get("subnet_id"): - model_filters.append(model.subnet_id.in_(filters["subnet_id"])) - - if filters.get("deallocated"): - model_filters.append(model.deallocated == filters["deallocated"]) - - if filters.get("_deallocated") is not None: - if filters.get("_deallocated"): - model_filters.append(model._deallocated == 1) - else: - model_filters.append(model._deallocated != 1) - - if filters.get("address"): - model_filters.append(model.address == filters["address"]) - - if filters.get("version"): - model_filters.append(model.version.in_(filters["version"])) - - if filters.get("ip_version"): - model_filters.append(model.ip_version == filters["ip_version"]) - - if filters.get("ip_address"): - model_filters.append(model.address.in_( - [ip.ipv6().value for ip in filters["ip_address"]])) - - if filters.get("mac_address_range_id"): - model_filters.append(model.mac_address_range_id == - filters["mac_address_range_id"]) - - if filters.get("cidr"): - model_filters.append(model.cidr == filters["cidr"]) + # Sanitize incoming filters to only attributes that exist in the model. + # NOTE: Filters for unusable attributes are silently dropped here. + # NOTE: When the filter key != attribute key, a conditional must be added + # here. + model_attrs = _model_attrs(model) + filters = {x: y for x, y in filters.items() + if x in model_attrs or + (x == "tenant_id" and model == models.IPAddress) or + (x == "ip_address" and model == models.IPAddress) or + (x == "reuse_after" and model in (models.IPAddress, + models.MacAddress))} # Inject the tenant id if none is set. We don't need unqualified queries. # This works even when a non-shared, other-tenant owned network is passed @@ -135,20 +108,47 @@ def _model_query(context, model, filters, fields=None): if not filters.get("tenant_id") and not context.is_admin: filters["tenant_id"] = [context.tenant_id] - # Begin:Added for RM6299 - if filters.get("used_by_tenant_id"): - model_filters.append(model.used_by_tenant_id.in_( - filters["used_by_tenant_id"])) + if model == models.SecurityGroupRule: + sg_rule_attribs = ["direction", "port_range_max", "port_range_min"] + eq_filters.extend(sg_rule_attribs) - if filters.get("tenant_id"): - if model == models.IPAddress: - model_filters.append(model.used_by_tenant_id.in_( - filters["tenant_id"])) - else: - model_filters.append(model.tenant_id.in_(filters["tenant_id"])) - # End: Added for RM6299 - if filters.get("device_owner"): - model_filters.append(model.device_owner.in_(filters["device_owner"])) + for key, value in filters.items(): + # This is mostly for unittests, as they're configured to send in None + if value is None: + continue + if key in in_filters: + model_type = getattr(model, key) + model_filters.append(model_type.in_(value)) + elif key in eq_filters: + model_type = getattr(model, key) + model_filters.append(model_type == value) + elif key == "_deallocated": + if value: + model_filters.append(model._deallocated == 1) + else: + model_filters.append(model._deallocated != 1) + elif key == "ethertype": + etypes = [] + for etype in value: + etypes.append(protocols.translate_ethertype(etype)) + model_filters.append(model.ethertype.in_(etypes)) + elif key == "ip_address": + model_filters.append(model.address.in_( + [ip.ipv6().value for ip in value])) + elif key == 'protocol': + pnums = [] + for version in (protocols.PROTOCOLS_V4, protocols.PROTOCOLS_V6): + pnums.extend([y for x, y in version.items() if x in value]) + model_filters.append(model.protocol.in_(pnums)) + elif key == "reuse_after": + reuse = (timeutils.utcnow() - + datetime.timedelta(seconds=value)) + model_filters.append(model.deallocated_at <= reuse) + elif key == "tenant_id": + if model == models.IPAddress: + model_filters.append(model.used_by_tenant_id.in_(value)) + else: + model_filters.append(model.tenant_id.in_(value)) return model_filters diff --git a/quark/tests/test_db_api.py b/quark/tests/test_db_api.py index dff6821..bfb6e3e 100644 --- a/quark/tests/test_db_api.py +++ b/quark/tests/test_db_api.py @@ -15,11 +15,14 @@ import mock import netaddr +from oslo_log import log as logging from quark.db import api as db_api from quark.db import models from quark.tests.functional.base import BaseFunctionalTest +LOG = logging.getLogger(__name__) + class TestDBAPI(BaseFunctionalTest): def setUp(self): @@ -130,6 +133,70 @@ class TestDBAPI(BaseFunctionalTest): except Exception as e: self.fail("Expected no exceptions: %s" % e) + def test_model_query_with_IPAddress(self): + # NOTE: tenant_id filter will always be added + test_model = models.IPAddress + good_filter = {"network_id": [2]} + result = db_api._model_query(self.context, test_model, good_filter) + self.assertEqual(len(result), 2) + bad_filter = {"ethertype": "IPv4"} + result = db_api._model_query(self.context, test_model, bad_filter) + self.assertEqual(len(result), 1) + + def test_model_query_with_MacAddress(self): + test_model = models.MacAddress + good_filter = {"deallocated": True} + result = db_api._model_query(self.context, test_model, good_filter) + self.assertEqual(len(result), 2) + bad_filter = {"protocol": "ICMP"} + result = db_api._model_query(self.context, test_model, bad_filter) + self.assertEqual(len(result), 1) + + def test_model_query_with_Network(self): + test_model = models.Network + good_filter = {"name": ["BOB"]} + result = db_api._model_query(self.context, test_model, good_filter) + self.assertEqual(len(result), 2) + bad_filter = {"deallocated": True} + result = db_api._model_query(self.context, test_model, bad_filter) + self.assertEqual(len(result), 1) + + def test_model_query_with_Port(self): + test_model = models.Port + good_filter = {"device_id": [123]} + result = db_api._model_query(self.context, test_model, good_filter) + self.assertEqual(len(result), 2) + bad_filter = {"not_real": "BANANAS"} + result = db_api._model_query(self.context, test_model, bad_filter) + self.assertEqual(len(result), 1) + + def test_model_query_with_SecurityGroup(self): + test_model = models.SecurityGroup + good_filter = {"name": ["Abraham Lincoln"]} + result = db_api._model_query(self.context, test_model, good_filter) + self.assertEqual(len(result), 2) + bad_filter = {"segment_id": [123]} + result = db_api._model_query(self.context, test_model, bad_filter) + self.assertEqual(len(result), 1) + + def test_model_query_with_SecurityGroupRule(self): + test_model = models.SecurityGroupRule + good_filter = {"ethertype": ["IPv4"]} + result = db_api._model_query(self.context, test_model, good_filter) + self.assertEqual(len(result), 2) + bad_filter = {"made_up": "Moon Landing"} + result = db_api._model_query(self.context, test_model, bad_filter) + self.assertEqual(len(result), 1) + + def test_model_query_with_Subnet(self): + test_model = models.Subnet + good_filter = {"network_id": [42]} + result = db_api._model_query(self.context, test_model, good_filter) + self.assertEqual(len(result), 2) + bad_filter = {"subnet_id": [123]} + result = db_api._model_query(self.context, test_model, bad_filter) + self.assertEqual(len(result), 1) + def test_port_associate_ip(self): self.context.session.add = mock.Mock() mock_ports = [models.Port(id=str(x), network_id="2", ip_addresses=[]) diff --git a/tox.ini b/tox.ini index 461be94..21262cd 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py26,py27,flake8 +envlist = py27,flake8 [testenv] setenv = VIRTUAL_ENV={envdir}