Cleaning up _model_query and adding SG attribs

RM11464
This commit is contained in:
Josh Conant
2015-03-18 14:41:11 +00:00
parent f361b92f73
commit 367e0fbb65
4 changed files with 136 additions and 70 deletions

View File

@@ -1,7 +1,6 @@
language: python
python:
- "2.7"
- "2.6"
before_install:
- "export NO_EVENTLET=1"
install:

View File

@@ -25,9 +25,11 @@ from oslo_log import log as logging
from sqlalchemy import event
from sqlalchemy import func as sql_func
from sqlalchemy import and_, asc, desc, orm, or_, not_
from sqlalchemy.orm import class_mapper
from quark.db import models
from quark import network_strategy
from quark import protocols
STRATEGY = network_strategy.STRATEGY
@@ -68,65 +70,36 @@ def _listify(filters):
filters[key] = listified
def _model_attrs(model):
model_map = class_mapper(model)
model_attrs = [x.key for x in model_map.column_attrs]
if "_cidr" in model_attrs:
model_attrs.append("cidr")
if "_deallocated" in model_attrs:
model_attrs.append("deallocated")
return model_attrs
def _model_query(context, model, filters, fields=None):
filters = filters or {}
model_filters = []
eq_filters = ["address", "cidr", "deallocated", "ip_version",
"mac_address_range_id"]
in_filters = ["device_id", "device_owner", "group_id", "id", "mac_address",
"name", "network_id", "segment_id", "subnet_id",
"used_by_tenant_id", "version"]
if filters.get("name"):
model_filters.append(model.name.in_(filters["name"]))
if filters.get("network_id"):
model_filters.append(model.network_id.in_(filters["network_id"]))
if filters.get("mac_address"):
model_filters.append(model.mac_address.in_(filters["mac_address"]))
if filters.get("segment_id"):
model_filters.append(model.segment_id.in_(filters["segment_id"]))
if filters.get("id"):
model_filters.append(model.id.in_(filters["id"]))
if filters.get("group_id"):
model_filters.append(model.group_id.in_(filters["group_id"]))
if filters.get("reuse_after"):
reuse_after = filters["reuse_after"]
reuse = (timeutils.utcnow() -
datetime.timedelta(seconds=reuse_after))
model_filters.append(model.deallocated_at <= reuse)
if filters.get("subnet_id"):
model_filters.append(model.subnet_id.in_(filters["subnet_id"]))
if filters.get("deallocated"):
model_filters.append(model.deallocated == filters["deallocated"])
if filters.get("_deallocated") is not None:
if filters.get("_deallocated"):
model_filters.append(model._deallocated == 1)
else:
model_filters.append(model._deallocated != 1)
if filters.get("address"):
model_filters.append(model.address == filters["address"])
if filters.get("version"):
model_filters.append(model.version.in_(filters["version"]))
if filters.get("ip_version"):
model_filters.append(model.ip_version == filters["ip_version"])
if filters.get("ip_address"):
model_filters.append(model.address.in_(
[ip.ipv6().value for ip in filters["ip_address"]]))
if filters.get("mac_address_range_id"):
model_filters.append(model.mac_address_range_id ==
filters["mac_address_range_id"])
if filters.get("cidr"):
model_filters.append(model.cidr == filters["cidr"])
# Sanitize incoming filters to only attributes that exist in the model.
# NOTE: Filters for unusable attributes are silently dropped here.
# NOTE: When the filter key != attribute key, a conditional must be added
# here.
model_attrs = _model_attrs(model)
filters = {x: y for x, y in filters.items()
if x in model_attrs or
(x == "tenant_id" and model == models.IPAddress) or
(x == "ip_address" and model == models.IPAddress) or
(x == "reuse_after" and model in (models.IPAddress,
models.MacAddress))}
# Inject the tenant id if none is set. We don't need unqualified queries.
# This works even when a non-shared, other-tenant owned network is passed
@@ -135,20 +108,47 @@ def _model_query(context, model, filters, fields=None):
if not filters.get("tenant_id") and not context.is_admin:
filters["tenant_id"] = [context.tenant_id]
# Begin:Added for RM6299
if filters.get("used_by_tenant_id"):
model_filters.append(model.used_by_tenant_id.in_(
filters["used_by_tenant_id"]))
if model == models.SecurityGroupRule:
sg_rule_attribs = ["direction", "port_range_max", "port_range_min"]
eq_filters.extend(sg_rule_attribs)
if filters.get("tenant_id"):
if model == models.IPAddress:
model_filters.append(model.used_by_tenant_id.in_(
filters["tenant_id"]))
else:
model_filters.append(model.tenant_id.in_(filters["tenant_id"]))
# End: Added for RM6299
if filters.get("device_owner"):
model_filters.append(model.device_owner.in_(filters["device_owner"]))
for key, value in filters.items():
# This is mostly for unittests, as they're configured to send in None
if value is None:
continue
if key in in_filters:
model_type = getattr(model, key)
model_filters.append(model_type.in_(value))
elif key in eq_filters:
model_type = getattr(model, key)
model_filters.append(model_type == value)
elif key == "_deallocated":
if value:
model_filters.append(model._deallocated == 1)
else:
model_filters.append(model._deallocated != 1)
elif key == "ethertype":
etypes = []
for etype in value:
etypes.append(protocols.translate_ethertype(etype))
model_filters.append(model.ethertype.in_(etypes))
elif key == "ip_address":
model_filters.append(model.address.in_(
[ip.ipv6().value for ip in value]))
elif key == 'protocol':
pnums = []
for version in (protocols.PROTOCOLS_V4, protocols.PROTOCOLS_V6):
pnums.extend([y for x, y in version.items() if x in value])
model_filters.append(model.protocol.in_(pnums))
elif key == "reuse_after":
reuse = (timeutils.utcnow() -
datetime.timedelta(seconds=value))
model_filters.append(model.deallocated_at <= reuse)
elif key == "tenant_id":
if model == models.IPAddress:
model_filters.append(model.used_by_tenant_id.in_(value))
else:
model_filters.append(model.tenant_id.in_(value))
return model_filters

View File

@@ -15,11 +15,14 @@
import mock
import netaddr
from oslo_log import log as logging
from quark.db import api as db_api
from quark.db import models
from quark.tests.functional.base import BaseFunctionalTest
LOG = logging.getLogger(__name__)
class TestDBAPI(BaseFunctionalTest):
def setUp(self):
@@ -130,6 +133,70 @@ class TestDBAPI(BaseFunctionalTest):
except Exception as e:
self.fail("Expected no exceptions: %s" % e)
def test_model_query_with_IPAddress(self):
# NOTE: tenant_id filter will always be added
test_model = models.IPAddress
good_filter = {"network_id": [2]}
result = db_api._model_query(self.context, test_model, good_filter)
self.assertEqual(len(result), 2)
bad_filter = {"ethertype": "IPv4"}
result = db_api._model_query(self.context, test_model, bad_filter)
self.assertEqual(len(result), 1)
def test_model_query_with_MacAddress(self):
test_model = models.MacAddress
good_filter = {"deallocated": True}
result = db_api._model_query(self.context, test_model, good_filter)
self.assertEqual(len(result), 2)
bad_filter = {"protocol": "ICMP"}
result = db_api._model_query(self.context, test_model, bad_filter)
self.assertEqual(len(result), 1)
def test_model_query_with_Network(self):
test_model = models.Network
good_filter = {"name": ["BOB"]}
result = db_api._model_query(self.context, test_model, good_filter)
self.assertEqual(len(result), 2)
bad_filter = {"deallocated": True}
result = db_api._model_query(self.context, test_model, bad_filter)
self.assertEqual(len(result), 1)
def test_model_query_with_Port(self):
test_model = models.Port
good_filter = {"device_id": [123]}
result = db_api._model_query(self.context, test_model, good_filter)
self.assertEqual(len(result), 2)
bad_filter = {"not_real": "BANANAS"}
result = db_api._model_query(self.context, test_model, bad_filter)
self.assertEqual(len(result), 1)
def test_model_query_with_SecurityGroup(self):
test_model = models.SecurityGroup
good_filter = {"name": ["Abraham Lincoln"]}
result = db_api._model_query(self.context, test_model, good_filter)
self.assertEqual(len(result), 2)
bad_filter = {"segment_id": [123]}
result = db_api._model_query(self.context, test_model, bad_filter)
self.assertEqual(len(result), 1)
def test_model_query_with_SecurityGroupRule(self):
test_model = models.SecurityGroupRule
good_filter = {"ethertype": ["IPv4"]}
result = db_api._model_query(self.context, test_model, good_filter)
self.assertEqual(len(result), 2)
bad_filter = {"made_up": "Moon Landing"}
result = db_api._model_query(self.context, test_model, bad_filter)
self.assertEqual(len(result), 1)
def test_model_query_with_Subnet(self):
test_model = models.Subnet
good_filter = {"network_id": [42]}
result = db_api._model_query(self.context, test_model, good_filter)
self.assertEqual(len(result), 2)
bad_filter = {"subnet_id": [123]}
result = db_api._model_query(self.context, test_model, bad_filter)
self.assertEqual(len(result), 1)
def test_port_associate_ip(self):
self.context.session.add = mock.Mock()
mock_ports = [models.Port(id=str(x), network_id="2", ip_addresses=[])

View File

@@ -1,5 +1,5 @@
[tox]
envlist = py26,py27,flake8
envlist = py27,flake8
[testenv]
setenv = VIRTUAL_ENV={envdir}