Cleaning up _model_query and adding SG attribs

RM11464
This commit is contained in:
Josh Conant
2015-03-18 14:41:11 +00:00
parent f361b92f73
commit 367e0fbb65
4 changed files with 136 additions and 70 deletions

View File

@@ -1,7 +1,6 @@
language: python language: python
python: python:
- "2.7" - "2.7"
- "2.6"
before_install: before_install:
- "export NO_EVENTLET=1" - "export NO_EVENTLET=1"
install: install:

View File

@@ -25,9 +25,11 @@ from oslo_log import log as logging
from sqlalchemy import event from sqlalchemy import event
from sqlalchemy import func as sql_func from sqlalchemy import func as sql_func
from sqlalchemy import and_, asc, desc, orm, or_, not_ from sqlalchemy import and_, asc, desc, orm, or_, not_
from sqlalchemy.orm import class_mapper
from quark.db import models from quark.db import models
from quark import network_strategy from quark import network_strategy
from quark import protocols
STRATEGY = network_strategy.STRATEGY STRATEGY = network_strategy.STRATEGY
@@ -68,65 +70,36 @@ def _listify(filters):
filters[key] = listified filters[key] = listified
def _model_attrs(model):
model_map = class_mapper(model)
model_attrs = [x.key for x in model_map.column_attrs]
if "_cidr" in model_attrs:
model_attrs.append("cidr")
if "_deallocated" in model_attrs:
model_attrs.append("deallocated")
return model_attrs
def _model_query(context, model, filters, fields=None): def _model_query(context, model, filters, fields=None):
filters = filters or {} filters = filters or {}
model_filters = [] model_filters = []
eq_filters = ["address", "cidr", "deallocated", "ip_version",
"mac_address_range_id"]
in_filters = ["device_id", "device_owner", "group_id", "id", "mac_address",
"name", "network_id", "segment_id", "subnet_id",
"used_by_tenant_id", "version"]
if filters.get("name"): # Sanitize incoming filters to only attributes that exist in the model.
model_filters.append(model.name.in_(filters["name"])) # NOTE: Filters for unusable attributes are silently dropped here.
# NOTE: When the filter key != attribute key, a conditional must be added
if filters.get("network_id"): # here.
model_filters.append(model.network_id.in_(filters["network_id"])) model_attrs = _model_attrs(model)
filters = {x: y for x, y in filters.items()
if filters.get("mac_address"): if x in model_attrs or
model_filters.append(model.mac_address.in_(filters["mac_address"])) (x == "tenant_id" and model == models.IPAddress) or
(x == "ip_address" and model == models.IPAddress) or
if filters.get("segment_id"): (x == "reuse_after" and model in (models.IPAddress,
model_filters.append(model.segment_id.in_(filters["segment_id"])) models.MacAddress))}
if filters.get("id"):
model_filters.append(model.id.in_(filters["id"]))
if filters.get("group_id"):
model_filters.append(model.group_id.in_(filters["group_id"]))
if filters.get("reuse_after"):
reuse_after = filters["reuse_after"]
reuse = (timeutils.utcnow() -
datetime.timedelta(seconds=reuse_after))
model_filters.append(model.deallocated_at <= reuse)
if filters.get("subnet_id"):
model_filters.append(model.subnet_id.in_(filters["subnet_id"]))
if filters.get("deallocated"):
model_filters.append(model.deallocated == filters["deallocated"])
if filters.get("_deallocated") is not None:
if filters.get("_deallocated"):
model_filters.append(model._deallocated == 1)
else:
model_filters.append(model._deallocated != 1)
if filters.get("address"):
model_filters.append(model.address == filters["address"])
if filters.get("version"):
model_filters.append(model.version.in_(filters["version"]))
if filters.get("ip_version"):
model_filters.append(model.ip_version == filters["ip_version"])
if filters.get("ip_address"):
model_filters.append(model.address.in_(
[ip.ipv6().value for ip in filters["ip_address"]]))
if filters.get("mac_address_range_id"):
model_filters.append(model.mac_address_range_id ==
filters["mac_address_range_id"])
if filters.get("cidr"):
model_filters.append(model.cidr == filters["cidr"])
# Inject the tenant id if none is set. We don't need unqualified queries. # Inject the tenant id if none is set. We don't need unqualified queries.
# This works even when a non-shared, other-tenant owned network is passed # This works even when a non-shared, other-tenant owned network is passed
@@ -135,20 +108,47 @@ def _model_query(context, model, filters, fields=None):
if not filters.get("tenant_id") and not context.is_admin: if not filters.get("tenant_id") and not context.is_admin:
filters["tenant_id"] = [context.tenant_id] filters["tenant_id"] = [context.tenant_id]
# Begin:Added for RM6299 if model == models.SecurityGroupRule:
if filters.get("used_by_tenant_id"): sg_rule_attribs = ["direction", "port_range_max", "port_range_min"]
model_filters.append(model.used_by_tenant_id.in_( eq_filters.extend(sg_rule_attribs)
filters["used_by_tenant_id"]))
if filters.get("tenant_id"): for key, value in filters.items():
if model == models.IPAddress: # This is mostly for unittests, as they're configured to send in None
model_filters.append(model.used_by_tenant_id.in_( if value is None:
filters["tenant_id"])) continue
if key in in_filters:
model_type = getattr(model, key)
model_filters.append(model_type.in_(value))
elif key in eq_filters:
model_type = getattr(model, key)
model_filters.append(model_type == value)
elif key == "_deallocated":
if value:
model_filters.append(model._deallocated == 1)
else: else:
model_filters.append(model.tenant_id.in_(filters["tenant_id"])) model_filters.append(model._deallocated != 1)
# End: Added for RM6299 elif key == "ethertype":
if filters.get("device_owner"): etypes = []
model_filters.append(model.device_owner.in_(filters["device_owner"])) for etype in value:
etypes.append(protocols.translate_ethertype(etype))
model_filters.append(model.ethertype.in_(etypes))
elif key == "ip_address":
model_filters.append(model.address.in_(
[ip.ipv6().value for ip in value]))
elif key == 'protocol':
pnums = []
for version in (protocols.PROTOCOLS_V4, protocols.PROTOCOLS_V6):
pnums.extend([y for x, y in version.items() if x in value])
model_filters.append(model.protocol.in_(pnums))
elif key == "reuse_after":
reuse = (timeutils.utcnow() -
datetime.timedelta(seconds=value))
model_filters.append(model.deallocated_at <= reuse)
elif key == "tenant_id":
if model == models.IPAddress:
model_filters.append(model.used_by_tenant_id.in_(value))
else:
model_filters.append(model.tenant_id.in_(value))
return model_filters return model_filters

View File

@@ -15,11 +15,14 @@
import mock import mock
import netaddr import netaddr
from oslo_log import log as logging
from quark.db import api as db_api from quark.db import api as db_api
from quark.db import models from quark.db import models
from quark.tests.functional.base import BaseFunctionalTest from quark.tests.functional.base import BaseFunctionalTest
LOG = logging.getLogger(__name__)
class TestDBAPI(BaseFunctionalTest): class TestDBAPI(BaseFunctionalTest):
def setUp(self): def setUp(self):
@@ -130,6 +133,70 @@ class TestDBAPI(BaseFunctionalTest):
except Exception as e: except Exception as e:
self.fail("Expected no exceptions: %s" % e) self.fail("Expected no exceptions: %s" % e)
def test_model_query_with_IPAddress(self):
# NOTE: tenant_id filter will always be added
test_model = models.IPAddress
good_filter = {"network_id": [2]}
result = db_api._model_query(self.context, test_model, good_filter)
self.assertEqual(len(result), 2)
bad_filter = {"ethertype": "IPv4"}
result = db_api._model_query(self.context, test_model, bad_filter)
self.assertEqual(len(result), 1)
def test_model_query_with_MacAddress(self):
test_model = models.MacAddress
good_filter = {"deallocated": True}
result = db_api._model_query(self.context, test_model, good_filter)
self.assertEqual(len(result), 2)
bad_filter = {"protocol": "ICMP"}
result = db_api._model_query(self.context, test_model, bad_filter)
self.assertEqual(len(result), 1)
def test_model_query_with_Network(self):
test_model = models.Network
good_filter = {"name": ["BOB"]}
result = db_api._model_query(self.context, test_model, good_filter)
self.assertEqual(len(result), 2)
bad_filter = {"deallocated": True}
result = db_api._model_query(self.context, test_model, bad_filter)
self.assertEqual(len(result), 1)
def test_model_query_with_Port(self):
test_model = models.Port
good_filter = {"device_id": [123]}
result = db_api._model_query(self.context, test_model, good_filter)
self.assertEqual(len(result), 2)
bad_filter = {"not_real": "BANANAS"}
result = db_api._model_query(self.context, test_model, bad_filter)
self.assertEqual(len(result), 1)
def test_model_query_with_SecurityGroup(self):
test_model = models.SecurityGroup
good_filter = {"name": ["Abraham Lincoln"]}
result = db_api._model_query(self.context, test_model, good_filter)
self.assertEqual(len(result), 2)
bad_filter = {"segment_id": [123]}
result = db_api._model_query(self.context, test_model, bad_filter)
self.assertEqual(len(result), 1)
def test_model_query_with_SecurityGroupRule(self):
test_model = models.SecurityGroupRule
good_filter = {"ethertype": ["IPv4"]}
result = db_api._model_query(self.context, test_model, good_filter)
self.assertEqual(len(result), 2)
bad_filter = {"made_up": "Moon Landing"}
result = db_api._model_query(self.context, test_model, bad_filter)
self.assertEqual(len(result), 1)
def test_model_query_with_Subnet(self):
test_model = models.Subnet
good_filter = {"network_id": [42]}
result = db_api._model_query(self.context, test_model, good_filter)
self.assertEqual(len(result), 2)
bad_filter = {"subnet_id": [123]}
result = db_api._model_query(self.context, test_model, bad_filter)
self.assertEqual(len(result), 1)
def test_port_associate_ip(self): def test_port_associate_ip(self):
self.context.session.add = mock.Mock() self.context.session.add = mock.Mock()
mock_ports = [models.Port(id=str(x), network_id="2", ip_addresses=[]) mock_ports = [models.Port(id=str(x), network_id="2", ip_addresses=[])

View File

@@ -1,5 +1,5 @@
[tox] [tox]
envlist = py26,py27,flake8 envlist = py27,flake8
[testenv] [testenv]
setenv = VIRTUAL_ENV={envdir} setenv = VIRTUAL_ENV={envdir}