Migrating "Classification Groups"

This patch removes the "Classification Groups" from
Neutron Classifier's migrations.

Reasoning:
For other extensions to use Neutron Classifier's
"Classification Groups" they will need to make a foreign
key association to the "classification group's" id.

When the DB migrations are being run, the "Classification
Group" table has to exist before it can be referenced.
As Neutron runs it's own DB migrations before other extensions,
this means that the QoS extensions would not be able to use
"Classifications".

This patch is the first of 2 patches, with the 2nd patch
inserting the "Classification Groups" into Neutron's migrations.
Link to 2nd patch: https://review.opendev.org/#/c/636333/

Note: Only the L2 Agent is currently being imported.

Recent changes include:
 - Change functional tests to correspond to the class migrations
 - Change neutron_classifier/objects/classifications to classification
   to correspond to the neutron/objects naming convention.
 - Change all objects imported from neutron.objects.classification
   to be imported as n_class_obj.
 - Change all objects imported from neutron_classifier/objects/
   classification to be imported as class_obj.

Depends-On: https://review.opendev.org/636333
Change-Id: Ibf5424b643b027da1fd03780f7ef81b970400c28
This commit is contained in:
Sara Nierodzik 2019-02-12 13:13:20 +00:00
parent 1603f200d7
commit bb2b5673b1
22 changed files with 692 additions and 398 deletions

View File

@ -13,19 +13,19 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
from neutron.objects import classification as n_class_obj
from neutron_classifier.objects import classification as class_obj
from neutron_classifier.objects import classifications as cs COMMON_FIELDS = n_class_obj.ClassificationBase.fields.keys()
FIELDS_IPV4 = list(set(class_obj.IPV4Classification.fields.keys()) -
COMMON_FIELDS = cs.ClassificationBase.fields.keys()
FIELDS_IPV4 = list(set(cs.IPV4Classification.fields.keys()) -
set(COMMON_FIELDS)) set(COMMON_FIELDS))
FIELDS_IPV6 = list(set(cs.IPV6Classification.fields.keys()) - FIELDS_IPV6 = list(set(class_obj.IPV6Classification.fields.keys()) -
set(COMMON_FIELDS)) set(COMMON_FIELDS))
FIELDS_TCP = list(set(cs.TCPClassification.fields.keys()) - FIELDS_TCP = list(set(class_obj.TCPClassification.fields.keys()) -
set(COMMON_FIELDS)) set(COMMON_FIELDS))
FIELDS_UDP = list(set(cs.UDPClassification.fields.keys()) - FIELDS_UDP = list(set(class_obj.UDPClassification.fields.keys()) -
set(COMMON_FIELDS)) set(COMMON_FIELDS))
FIELDS_ETHERNET = list(set(cs.EthernetClassification.fields.keys()) - FIELDS_ETHERNET = list(set(class_obj.EthernetClassification.fields.keys()) -
set(COMMON_FIELDS)) set(COMMON_FIELDS))
@ -34,3 +34,10 @@ SUPPORTED_FIELDS = {'ipv4': FIELDS_IPV4,
'tcp': FIELDS_TCP, 'tcp': FIELDS_TCP,
'udp': FIELDS_UDP, 'udp': FIELDS_UDP,
'ethernet': FIELDS_ETHERNET} 'ethernet': FIELDS_ETHERNET}
# Method names for receiving classifications
PRECOMMIT_POSTFIX = '_precommit'
CREATE_CLASS = 'create_classification'
CREATE_CLASS_PRECOMMIT = CREATE_CLASS + PRECOMMIT_POSTFIX
DELETE_CLASS = 'delete_classification'
DELETE_CLASS_PRECOMMIT = DELETE_CLASS + PRECOMMIT_POSTFIX

View File

@ -12,6 +12,7 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
from neutron.objects import classification as cs_base
from neutron_classifier.common import eth_validators from neutron_classifier.common import eth_validators
from neutron_classifier.common import exceptions from neutron_classifier.common import exceptions
from neutron_classifier.common import ipv4_validators from neutron_classifier.common import ipv4_validators
@ -19,7 +20,7 @@ from neutron_classifier.common import ipv6_validators
from neutron_classifier.common import tcp_validators from neutron_classifier.common import tcp_validators
from neutron_classifier.common import udp_validators from neutron_classifier.common import udp_validators
from neutron_classifier.db import models from neutron_classifier.db import models
from neutron_classifier.objects import classifications from neutron_classifier.objects import classification as class_obj
from neutron_lib.db import api as db_api from neutron_lib.db import api as db_api
@ -33,9 +34,9 @@ type_validators['udp'] = udp_validators.validators_dict
def check_valid_classifications(context, cs): def check_valid_classifications(context, cs):
for c_id in cs: for c_id in cs:
c_model = classifications.ClassificationBase c_model = cs_base.ClassificationBase
c = c_model.get_object(context, id=c_id) c = c_model.get_object(context, id=c_id)
c_type_clas = classifications.CLASS_MAP[c.c_type] c_type_clas = class_obj.CLASS_MAP[c.c_type]
classification = c_type_clas.get_object(context, id=c_id) classification = c_type_clas.get_object(context, id=c_id)
if not classification or (classification.id != c_id): if not classification or (classification.id != c_id):
raise exceptions.InvalidClassificationId() raise exceptions.InvalidClassificationId()
@ -55,12 +56,12 @@ def check_can_delete_classification_group(context, cg_id):
classification group, meaning is already mapped to a parent classification classification group, meaning is already mapped to a parent classification
group. In that case we cannot delete it and will raise an exception. group. In that case we cannot delete it and will raise an exception.
""" """
cgs = classifications.ClassificationGroup.get_objects(context) cgs = cs_base.ClassificationGroup.get_objects(context)
for cg in cgs: for cg in cgs:
with db_api.CONTEXT_WRITER.using(context): with db_api.CONTEXT_WRITER.using(context):
cg_obj = classifications.ClassificationGroup.get_object(context, cg_obj = cs_base.ClassificationGroup.get_object(context,
id=cg.id) id=cg.id)
mapped_cgs = classifications._get_mapped_classification_groups( mapped_cgs = class_obj._get_mapped_classification_groups(
context, cg_obj) context, cg_obj)
if cg_id in [mcg.id for mcg in mapped_cgs]: if cg_id in [mcg.id for mcg in mapped_cgs]:
raise exceptions.ConsumedClassificationGroup() raise exceptions.ConsumedClassificationGroup()

View File

@ -18,10 +18,11 @@ from oslo_utils import uuidutils
from neutron_lib.db import api as db_api from neutron_lib.db import api as db_api
from neutron.objects import base as base_obj from neutron.objects import base as base_obj
from neutron.objects import classification as n_class_obj
from neutron_classifier.common import exceptions from neutron_classifier.common import exceptions
from neutron_classifier.common import validators from neutron_classifier.common import validators
from neutron_classifier.objects import classifications from neutron_classifier.objects import classification as class_obj
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -51,7 +52,7 @@ class TrafficClassificationGroupPlugin(object):
db_dict = details db_dict = details
if 'tenant_id' in details: if 'tenant_id' in details:
del details['tenant_id'] del details['tenant_id']
cg = classifications.ClassificationGroup(context, **details) cg = n_class_obj.ClassificationGroup(context, **details)
with db_api.CONTEXT_WRITER.using(context): with db_api.CONTEXT_WRITER.using(context):
cg.create() cg.create()
@ -60,7 +61,7 @@ class TrafficClassificationGroupPlugin(object):
with db_api.CONTEXT_WRITER.using(context): with db_api.CONTEXT_WRITER.using(context):
if c_flag: if c_flag:
for cl in mappings['c_ids']: for cl in mappings['c_ids']:
cg_c_mapping = classifications.CGToClassificationMapping( cg_c_mapping = n_class_obj.CGToClassificationMapping(
context, context,
container_cg_id=cg.id, container_cg_id=cg.id,
stored_classification_id=cl) stored_classification_id=cl)
@ -68,7 +69,7 @@ class TrafficClassificationGroupPlugin(object):
if cg_flag: if cg_flag:
for cg_id in mappings['cg_ids']: for cg_id in mappings['cg_ids']:
cg_cg_mapping =\ cg_cg_mapping =\
classifications.CGToClassificationGroupMapping( n_class_obj.CGToClassificationGroupMapping(
context, context,
container_cg_id=cg.id, container_cg_id=cg.id,
stored_cg_id=cg_id stored_cg_id=cg_id
@ -84,7 +85,7 @@ class TrafficClassificationGroupPlugin(object):
def delete_classification_group(self, context, classification_group_id): def delete_classification_group(self, context, classification_group_id):
if validators.check_can_delete_classification_group( if validators.check_can_delete_classification_group(
context, classification_group_id): context, classification_group_id):
cg = classifications.ClassificationGroup.get_object( cg = n_class_obj.ClassificationGroup.get_object(
context, id=classification_group_id) context, id=classification_group_id)
with db_api.CONTEXT_WRITER.using(context): with db_api.CONTEXT_WRITER.using(context):
cg.delete() cg.delete()
@ -98,7 +99,7 @@ class TrafficClassificationGroupPlugin(object):
if key not in valid_keys: if key not in valid_keys:
raise exceptions.InvalidUpdateRequest() raise exceptions.InvalidUpdateRequest()
with db_api.CONTEXT_WRITER.using(context): with db_api.CONTEXT_WRITER.using(context):
cg = classifications.ClassificationGroup.update_object( cg = n_class_obj.ClassificationGroup.update_object(
context, fields_to_update, id=classification_group_id) context, fields_to_update, id=classification_group_id)
db_dict = self._make_db_dict(cg) db_dict = self._make_db_dict(cg)
return db_dict return db_dict
@ -140,12 +141,12 @@ class TrafficClassificationGroupPlugin(object):
def get_classification_group(self, context, classification_group_id, def get_classification_group(self, context, classification_group_id,
fields=None): fields=None):
with db_api.CONTEXT_WRITER.using(context): with db_api.CONTEXT_WRITER.using(context):
cg = classifications.ClassificationGroup.get_object( cg = n_class_obj.ClassificationGroup.get_object(
context, id=classification_group_id) context, id=classification_group_id)
db_dict = self._make_db_dict(cg) db_dict = self._make_db_dict(cg)
mapped_cs = classifications._get_mapped_classifications(context, mapped_cs = class_obj._get_mapped_classifications(context,
cg) cg)
mapped_cgs = classifications._get_mapped_classification_groups( mapped_cgs = class_obj._get_mapped_classification_groups(
context, cg) context, cg)
c_dict = self._make_c_dicts(mapped_cs) c_dict = self._make_c_dicts(mapped_cs)
cg_dict = self._make_db_dicts(mapped_cgs) cg_dict = self._make_db_dicts(mapped_cgs)
@ -157,7 +158,7 @@ class TrafficClassificationGroupPlugin(object):
marker=None, page_reverse=False, marker=None, page_reverse=False,
filters=None, fields=None): filters=None, fields=None):
pager = base_obj.Pager(sorts, limit, page_reverse, marker) pager = base_obj.Pager(sorts, limit, page_reverse, marker)
cgs = classifications.ClassificationGroup.get_objects(context, cgs = n_class_obj.ClassificationGroup.get_objects(context,
_pager=pager) pager=pager)
db_dict = self._make_db_dicts(cgs) db_dict = self._make_db_dicts(cgs)
return db_dict return db_dict

View File

@ -1,5 +1,3 @@
# Copyright 2017 Intel Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
# a copy of the License at # a copy of the License at
@ -30,63 +28,6 @@ down_revision = 'start_neutron_classifier'
def upgrade(): def upgrade():
op.create_table(
'classification_groups',
sa.Column('id', sa.String(length=36), primary_key=True),
sa.Column('name', sa.String(length=255)),
sa.Column('description', sa.String(length=255)),
sa.Column('project_id', sa.String(length=255),
index=True),
sa.Column('shared', sa.Boolean()),
sa.Column('operator', sa.Enum("AND", "OR", name="operator_types"),
nullable=False))
op.create_table(
'classificationgrouprbacs',
sa.Column('id', sa.String(length=36), primary_key=True,
nullable=False),
sa.Column('project_id', sa.String(length=255)),
sa.Column('target_tenant', sa.String(length=255),
nullable=False),
sa.Column('action', sa.String(length=255), nullable=False),
sa.Column('object_id', sa.String(length=36),
nullable=False),
sa.ForeignKeyConstraint(['object_id'],
['classification_groups.id'],
ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('target_tenant',
'object_id', 'action'))
op.create_index(op.f('ix_classificationgrouprbacs_project_id'),
'classificationgrouprbacs',
['project_id'], unique=False)
op.create_table(
'classifications',
sa.Column('id', sa.String(length=36), primary_key=True),
sa.Column('c_type', sa.String(length=36)),
sa.Column('name', sa.String(length=255)),
sa.Column('description', sa.String(length=255)),
sa.Column('negated', sa.Boolean()),
sa.Column('shared', sa.Boolean()),
sa.Column('project_id', sa.String(length=255),
index=True))
op.create_table(
'classification_group_to_classification_mappings',
sa.Column('container_cg_id', sa.String(length=36), sa.ForeignKey(
"classification_groups.id", ondelete="CASCADE"),
primary_key=True),
sa.Column('stored_classification_id', sa.String(length=36),
sa.ForeignKey("classifications.id"), primary_key=True))
op.create_table(
'classification_group_to_cg_mappings',
sa.Column('container_cg_id', sa.String(length=36), sa.ForeignKey(
"classification_groups.id", ondelete="CASCADE"),
primary_key=True),
sa.Column('stored_cg_id', sa.String(length=36), sa.ForeignKey(
"classification_groups.id"), primary_key=True))
op.create_table( op.create_table(
'ipv4_classifications', 'ipv4_classifications',

View File

@ -12,65 +12,14 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
from neutron_lib.db import model_base from neutron.db import classification as cs_db
from neutron_lib.db import model_query as mq from neutron_lib.db import model_query as mq
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import orm
class ClassificationGroup(model_base.BASEV2, model_base.HasId, class IPV4Classification(cs_db.ClassificationBase):
model_base.HasProject):
__tablename__ = 'classification_groups'
name = sa.Column(sa.String(255))
description = sa.Column(sa.String(255))
shared = sa.Column(sa.Boolean, default=False)
operator = sa.Column(sa.Enum('AND', 'OR'), default='AND', nullable=False)
classifications = orm.relationship(
"ClassificationBase", lazy="subquery",
secondary='classification_group_to_classification_mappings')
classification_groups = orm.relationship(
"ClassificationGroup", lazy="subquery",
secondary='classification_group_to_cg_mappings',
primaryjoin="ClassificationGroup.id=="
"CGToClassificationGroupMapping.container_cg_id",
secondaryjoin="ClassificationGroup.id=="
"CGToClassificationGroupMapping.stored_cg_id")
class CGToClassificationMapping(model_base.BASEV2):
__tablename__ = 'classification_group_to_classification_mappings'
container_cg_id = sa.Column(sa.String(36),
sa.ForeignKey('classification_groups.id',
ondelete='CASCADE'), primary_key=True)
classification = orm.relationship("ClassificationBase", lazy="subquery")
stored_classification_id = sa.Column(sa.String(36),
sa.ForeignKey('classifications.id'),
primary_key=True)
class CGToClassificationGroupMapping(model_base.BASEV2):
__tablename__ = 'classification_group_to_cg_mappings'
container_cg_id = sa.Column(sa.String(36),
sa.ForeignKey('classification_groups.id',
ondelete='CASCADE'), primary_key=True)
stored_cg_id = sa.Column(sa.String(36),
sa.ForeignKey('classification_groups.id'),
primary_key=True)
class ClassificationBase(model_base.HasId, model_base.HasProject,
model_base.BASEV2):
__tablename__ = 'classifications'
c_type = sa.Column(sa.String(36))
__mapper_args__ = {'polymorphic_on': c_type}
name = sa.Column(sa.String(255))
description = sa.Column(sa.String(255))
shared = sa.Column(sa.Boolean())
negated = sa.Column(sa.Boolean())
class IPV4Classification(ClassificationBase):
__tablename__ = 'ipv4_classifications' __tablename__ = 'ipv4_classifications'
__mapper_args__ = {'polymorphic_identity': 'ipv4'} __mapper_args__ = {'polymorphic_identity': 'ipv4'}
id = sa.Column(sa.String(36), sa.ForeignKey('classifications.id', id = sa.Column(sa.String(36), sa.ForeignKey('classifications.id',
@ -89,7 +38,7 @@ class IPV4Classification(ClassificationBase):
dst_addr = sa.Column(sa.String(19)) dst_addr = sa.Column(sa.String(19))
class IPV6Classification(ClassificationBase): class IPV6Classification(cs_db.ClassificationBase):
__tablename__ = 'ipv6_classifications' __tablename__ = 'ipv6_classifications'
__mapper_args__ = {'polymorphic_identity': 'ipv6'} __mapper_args__ = {'polymorphic_identity': 'ipv6'}
id = sa.Column(sa.String(36), sa.ForeignKey('classifications.id', id = sa.Column(sa.String(36), sa.ForeignKey('classifications.id',
@ -106,7 +55,7 @@ class IPV6Classification(ClassificationBase):
dst_addr = sa.Column(sa.String(49)) dst_addr = sa.Column(sa.String(49))
class EthernetClassification(ClassificationBase): class EthernetClassification(cs_db.ClassificationBase):
__tablename__ = 'ethernet_classifications' __tablename__ = 'ethernet_classifications'
__mapper_args__ = {'polymorphic_identity': 'ethernet'} __mapper_args__ = {'polymorphic_identity': 'ethernet'}
id = sa.Column(sa.String(36), sa.ForeignKey('classifications.id', id = sa.Column(sa.String(36), sa.ForeignKey('classifications.id',
@ -116,7 +65,7 @@ class EthernetClassification(ClassificationBase):
dst_addr = sa.Column(sa.String(17)) dst_addr = sa.Column(sa.String(17))
class UDPClassification(ClassificationBase): class UDPClassification(cs_db.ClassificationBase):
__tablename__ = 'udp_classifications' __tablename__ = 'udp_classifications'
__mapper_args__ = {'polymorphic_identity': 'udp'} __mapper_args__ = {'polymorphic_identity': 'udp'}
id = sa.Column(sa.String(36), sa.ForeignKey('classifications.id', id = sa.Column(sa.String(36), sa.ForeignKey('classifications.id',
@ -129,7 +78,7 @@ class UDPClassification(ClassificationBase):
length_max = sa.Column(sa.Integer()) length_max = sa.Column(sa.Integer())
class TCPClassification(ClassificationBase): class TCPClassification(cs_db.ClassificationBase):
__tablename__ = 'tcp_classifications' __tablename__ = 'tcp_classifications'
__mapper_args__ = {'polymorphic_identity': 'tcp'} __mapper_args__ = {'polymorphic_identity': 'tcp'}
id = sa.Column(sa.String(36), sa.ForeignKey('classifications.id', id = sa.Column(sa.String(36), sa.ForeignKey('classifications.id',
@ -147,7 +96,7 @@ class TCPClassification(ClassificationBase):
def _read_classification_group(context, id): def _read_classification_group(context, id):
"""Returns a classification group.""" """Returns a classification group."""
cg = mq.get_by_id(context, ClassificationGroup, id) cg = mq.get_by_id(context, cs_db.ClassificationGroup, id)
return cg return cg
@ -187,7 +136,8 @@ def _generate_dict_from_cg_db(model, fields=None):
def _read_all_classification_groups(plugin, context): def _read_all_classification_groups(plugin, context):
"""Returns all classification groups.""" """Returns all classification groups."""
class_group = plugin._get_collection(context, ClassificationGroup, class_group = plugin._get_collection(context,
cs_db.ClassificationGroup,
_generate_dict_from_cg_db) _generate_dict_from_cg_db)
return class_group return class_group

View File

@ -1,26 +0,0 @@
# Copyright 2017 Intel Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from neutron.db import rbac_db_models
from neutron_lib.db import model_base
class ClassificationGroupRBAC(rbac_db_models.RBACColumns, model_base.BASEV2):
"""RBAC table for classification groups."""
object_id = rbac_db_models._object_id_column('classification_groups.id')
object_type = 'classification_group'
def get_valid_actions(self):
return (rbac_db_models.ACCESS_SHARED)

View File

@ -13,5 +13,6 @@
def register_objects(): def register_objects():
# local import to avoid circular import failure # local import to avoid circular import failure
__import__('neutron_classifier.objects.classifications') __import__('neutron_classifier.objects.classification')
__import__('neutron_classifier.objects.classification_type') __import__('neutron_classifier.objects.classification_type')
__import__('neutron.objects.classification')

View File

@ -12,116 +12,17 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import abc
import six
from oslo_versionedobjects import base as obj_base from oslo_versionedobjects import base as obj_base
from oslo_versionedobjects import fields as obj_fields from oslo_versionedobjects import fields as obj_fields
from neutron.objects import base from neutron.objects import classification as cs_base
from neutron.objects import common_types
from neutron.objects import rbac_db
from neutron_lib.db import api as db_api from neutron_lib.db import api as db_api
from neutron_classifier.db import models from neutron_classifier.db import models
from neutron_classifier.db.rbac_db_models import ClassificationGroupRBAC
@obj_base.VersionedObjectRegistry.register @obj_base.VersionedObjectRegistry.register
class ClassificationGroup(rbac_db.NeutronRbacObject): class IPV4Classification(cs_base.ClassificationBase):
# Version 1.0: Initial version
VERSION = '1.0'
# required by RbacNeutronMetaclass
rbac_db_cls = ClassificationGroupRBAC
db_model = models.ClassificationGroup
fields = {
'id': common_types.UUIDField(),
'name': obj_fields.StringField(),
'description': obj_fields.StringField(),
'project_id': obj_fields.StringField(),
'shared': obj_fields.BooleanField(default=False),
'operator': obj_fields.EnumField(['AND', 'OR'], default='AND'),
}
fields_no_update = ['id', 'project_id']
@classmethod
def get_object(cls, context, **kwargs):
# We want to get the policy regardless of its tenant id. We'll make
# sure the tenant has permission to access the policy later on.
admin_context = context.elevated()
with db_api.autonested_transaction(admin_context.session):
obj = super(ClassificationGroup, cls).get_object(admin_context,
**kwargs)
if not obj or not cls.is_accessible(context, obj):
return
return obj
@classmethod
def get_bound_tenant_ids(cls, context, **kwargs):
# If we can return the policy regardless of tenant, we don't need
# to return the tenant id.
pass
@obj_base.VersionedObjectRegistry.register
class CGToClassificationMapping(base.NeutronDbObject):
VERSION = '1.0'
rbac_db_model = ClassificationGroupRBAC
db_model = models.CGToClassificationMapping
fields = {
'container_cg_id': common_types.UUIDField(),
'stored_classification_id': common_types.UUIDField()}
@obj_base.VersionedObjectRegistry.register
class CGToClassificationGroupMapping(base.NeutronDbObject):
VERSION = '1.0'
rbac_db_model = ClassificationGroupRBAC
db_model = models.CGToClassificationGroupMapping
fields = {
'container_cg_id': common_types.UUIDField(),
'stored_cg_id': common_types.UUIDField()
}
@six.add_metaclass(abc.ABCMeta)
class ClassificationBase(base.NeutronDbObject):
VERSION = '1.0'
db_model = models.ClassificationBase
fields = {
'id': common_types.UUIDField(),
'name': obj_fields.StringField(),
'description': obj_fields.StringField(),
'project_id': obj_fields.StringField(),
'shared': obj_fields.BooleanField(default=False),
'c_type': obj_fields.StringField(),
'negated': obj_fields.BooleanField(default=False),
}
fields_no_update = ['id', 'c_type']
@classmethod
def get_objects(cls, context, _pager=None, validate_filters=True,
**kwargs):
with db_api.autonested_transaction(context.session):
objects = super(ClassificationBase,
cls).get_objects(context, _pager,
validate_filters,
**kwargs)
return objects
@obj_base.VersionedObjectRegistry.register
class IPV4Classification(ClassificationBase):
VERSION = '1.0' VERSION = '1.0'
db_model = models.IPV4Classification db_model = models.IPV4Classification
@ -143,7 +44,7 @@ class IPV4Classification(ClassificationBase):
def create(self): def create(self):
with db_api.autonested_transaction(self.obj_context.session): with db_api.autonested_transaction(self.obj_context.session):
super(ClassificationBase, self).create() super(cs_base.ClassificationBase, self).create()
@classmethod @classmethod
def get_object(cls, context, **kwargs): def get_object(cls, context, **kwargs):
@ -155,7 +56,7 @@ class IPV4Classification(ClassificationBase):
@obj_base.VersionedObjectRegistry.register @obj_base.VersionedObjectRegistry.register
class IPV6Classification(ClassificationBase): class IPV6Classification(cs_base.ClassificationBase):
VERSION = '1.0' VERSION = '1.0'
db_model = models.IPV6Classification db_model = models.IPV6Classification
@ -175,7 +76,7 @@ class IPV6Classification(ClassificationBase):
def create(self): def create(self):
with db_api.autonested_transaction(self.obj_context.session): with db_api.autonested_transaction(self.obj_context.session):
super(ClassificationBase, self).create() super(cs_base.ClassificationBase, self).create()
@classmethod @classmethod
def get_object(cls, context, **kwargs): def get_object(cls, context, **kwargs):
@ -187,7 +88,7 @@ class IPV6Classification(ClassificationBase):
@obj_base.VersionedObjectRegistry.register @obj_base.VersionedObjectRegistry.register
class EthernetClassification(ClassificationBase): class EthernetClassification(cs_base.ClassificationBase):
VERSION = '1.0' VERSION = '1.0'
db_model = models.EthernetClassification db_model = models.EthernetClassification
@ -199,7 +100,7 @@ class EthernetClassification(ClassificationBase):
def create(self): def create(self):
with db_api.autonested_transaction(self.obj_context.session): with db_api.autonested_transaction(self.obj_context.session):
super(ClassificationBase, self).create() super(cs_base.ClassificationBase, self).create()
@classmethod @classmethod
def get_object(cls, context, **kwargs): def get_object(cls, context, **kwargs):
@ -211,7 +112,7 @@ class EthernetClassification(ClassificationBase):
@obj_base.VersionedObjectRegistry.register @obj_base.VersionedObjectRegistry.register
class UDPClassification(ClassificationBase): class UDPClassification(cs_base.ClassificationBase):
VERSION = '1.0' VERSION = '1.0'
db_model = models.UDPClassification db_model = models.UDPClassification
@ -226,7 +127,7 @@ class UDPClassification(ClassificationBase):
def create(self): def create(self):
with db_api.autonested_transaction(self.obj_context.session): with db_api.autonested_transaction(self.obj_context.session):
super(ClassificationBase, self).create() super(cs_base.ClassificationBase, self).create()
@classmethod @classmethod
def get_object(cls, context, **kwargs): def get_object(cls, context, **kwargs):
@ -238,7 +139,7 @@ class UDPClassification(ClassificationBase):
@obj_base.VersionedObjectRegistry.register @obj_base.VersionedObjectRegistry.register
class TCPClassification(ClassificationBase): class TCPClassification(cs_base.ClassificationBase):
VERSION = '1.0' VERSION = '1.0'
db_model = models.TCPClassification db_model = models.TCPClassification
@ -255,7 +156,7 @@ class TCPClassification(ClassificationBase):
def create(self): def create(self):
with db_api.autonested_transaction(self.obj_context.session): with db_api.autonested_transaction(self.obj_context.session):
super(ClassificationBase, self).create() super(cs_base.ClassificationBase, self).create()
@classmethod @classmethod
def get_object(cls, context, **kwargs): def get_object(cls, context, **kwargs):
@ -292,8 +193,10 @@ def _get_mapped_classification_groups(context, obj):
:param obj: ClassificationGroup object :param obj: ClassificationGroup object
:return: list of ClassificationGroup objects :return: list of ClassificationGroup objects
""" """
mapped_cgs = [ClassificationGroup._load_object(context, cg) for cg in mapped_cgs = [cs_base.ClassificationGroup._load_object(context,
models._read_classification_groups(context, obj.id)] cg)
for cg in models._read_classification_groups(context,
obj.id)]
return mapped_cgs return mapped_cgs

View File

@ -0,0 +1,127 @@
# Copyright 2019 Intel Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from oslo_log import log as logging
from neutron.api.rpc.callbacks import events as rpc_events
from neutron.api.rpc.callbacks.producer import registry as rpc_registry
from neutron.api.rpc.callbacks import resources
from neutron.api.rpc.handlers import resources_rpc
from neutron.objects import classification as n_class_obj
from neutron_classifier.common import constants as nc_consts
from neutron_classifier.db import models
from neutron_classifier.objects import classification as class_obj
from neutron_lib.db import api as db_api
from neutron_lib import rpc as n_rpc
import oslo_messaging
LOG = logging.getLogger(__name__)
class NeutronClassifierAdvertiserCallback(object):
"""Neutron Classifier RPC server"""
def __init__(self, adv):
self.target = oslo_messaging.Target(version='1.0')
self.adv = adv
def get_classification_group_mapping(self, context, **kwargs):
cg_id = kwargs['cg_id']
return self.adv._get_classification_group_mapping(context, cg_id)
def get_classification(self, context, **kwargs):
c_id = kwargs['c_id']
return self.adv._get_classification("classification", c_id,
context=context)
class ClassificationAdvertiser(object):
def __init__(self):
self.rpc_notifications_required = True
self._init_classification_topics()
rpc_registry.provide(self._get_classification,
n_class_obj.ClassificationBase.obj_name())
rpc_registry.provide(self._get_classification_group,
n_class_obj.ClassificationGroup.obj_name())
if self.rpc_notifications_required:
self.push_api = resources_rpc.ResourcesPushRpcApi()
def _init_classification_topics(self):
resources.register_resource_class(n_class_obj.ClassificationGroup)
resources.register_resource_class(n_class_obj.ClassificationBase)
for cls in class_obj.CLASS_MAP.values():
resources.register_resource_class(cls)
self.conn = n_rpc.Connection()
endpoints = [NeutronClassifierAdvertiserCallback(self)]
self.conn.create_consumer("q-classifier", endpoints,
fanout=False)
self.conn.consume_in_threads()
@staticmethod
def _get_classification(resource, classification_id, **kwargs):
context = kwargs.get('context')
if context is None:
LOG.warning(
'Received %(resource)s %(classification_id)s without context',
{'resource': resource, 'classification_id': classification_id})
return
c = n_class_obj.ClassificationBase(context, id=classification_id)
c_obj = class_obj.CLASS_MAP[c.c_type]
classification = c_obj.get_object(context, id=classification_id)
return classification
@staticmethod
def _get_classification_group(resource, cg_id, **kwargs):
context = kwargs.get('context')
if context is None:
LOG.warning(
'Received %(resource)s %(classification_id)s without context',
{'resource': resource, 'classification_id': cg_id})
return
cg = n_class_obj.ClassificationGroup.\
get_object(context,
id=cg_id)
return cg
@db_api.CONTEXT_READER
def _get_classification_group_mapping(self, context, cg_id):
with db_api.CONTEXT_READER.using(context):
mapped_db_cs = models._read_classifications(context, cg_id)
mapped_db_cgs = models._read_classification_groups(context, cg_id)
mapped_cs = [cs.id for cs in mapped_db_cs]
mapped_cgs = [cgs.id for cgs in mapped_db_cgs]
group_mappings = {'classifications': mapped_cs,
'classification_groups': mapped_cgs}
return group_mappings
def call(self, method_name, *args, **kwargs):
"""Helper method for calling a method across all extensions."""
if self.rpc_notifications_required:
context = kwargs.get('context') or args[0]
cls_obj = kwargs.get('classification') or args[1]
if method_name == nc_consts.CREATE_CLASS:
self.push_api.push(context, [cls_obj], rpc_events.CREATED)
elif method_name == nc_consts.DELETE_CLASS:
self.push_api.push(context, [cls_obj], rpc_events.DELETED)

View File

@ -0,0 +1,126 @@
# Copyright 2019 Intel Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from neutron.api.rpc.callbacks.consumer import registry
from neutron.api.rpc.callbacks import events
from neutron.api.rpc.callbacks import resources
from neutron.api.rpc.handlers import resources_rpc
from neutron.objects import classification as n_class_obj
from neutron_classifier.objects import classification as class_obj
from neutron_lib.agent import l2_extension
from neutron_lib import context
from neutron_lib import rpc as lib_rpc
from oslo_log import log as logging
import oslo_messaging
LOG = logging.getLogger(__name__)
class NeutronClassifierApi(object):
def __init__(self):
self.res_rpc = resources_rpc.ResourcesPullRpcApi()
self.context = context.get_admin_context_without_session()
target = oslo_messaging.Target(topic="q-classifier",
version='1.0')
self.client = lib_rpc.get_client(target)
def get_classification(self, class_id):
cctx = self.client.prepare()
return cctx.call(self.context, "get_classification", c_id=class_id)
def get_classification_group(self, class_grp_id):
return self.res_rpc.pull(self.context,
n_class_obj.ClassificationGroup.obj_name(),
class_grp_id)
def _get_classification_group_mapping(self, class_grp_id):
cctx = self.client.prepare()
return cctx.call(self.context, 'get_classification_group_mapping',
cg_id=class_grp_id)
class NeutronClassifierExtension(l2_extension.L2AgentExtension):
SUPPORTED_RESOURCE_TYPES = [
n_class_obj.ClassificationGroup.obj_name(),
n_class_obj.ClassificationBase.obj_name(),
class_obj.EthernetClassification.obj_name(),
class_obj.IPV4Classification.obj_name(),
class_obj.IPV6Classification.obj_name(),
class_obj.UDPClassification.obj_name(),
class_obj.TCPClassification.obj_name()]
def __init__(self):
super(NeutronClassifierExtension, self).__init__()
resources.register_resource_class(n_class_obj.ClassificationGroup)
resources.register_resource_class(n_class_obj.ClassificationBase)
self.class_type_list = []
for cls in class_obj.CLASS_MAP.values():
resources.register_resource_class(cls)
self.class_type_list.append(cls.obj_name())
def consume_api(self, agent_api):
self.agent_api = agent_api
agent_api.register_classification_api(NeutronClassifierApi())
def initialize(self, connection, driver_type):
super(NeutronClassifierExtension, self).initialize(connection,
driver_type)
self.resource_rpc = resources_rpc.ResourcesPullRpcApi()
self._register_rpc_consumers(connection)
def handle_port(self, context, port):
pass
def delete_port(self, context, port):
pass
def _register_rpc_consumers(self, connection):
'''Allows an extension to receive notifications.
The notification shows the updates made to
items of interest.
'''
endpoints = [resources_rpc.ResourcesPushRpcCallback()]
for resource_type in self.SUPPORTED_RESOURCE_TYPES:
registry.register(self.handle_notification, resource_type)
topic = resources_rpc.resource_type_versioned_topic(resource_type)
connection.create_consumer(topic, endpoints, fanout=True)
def handle_notification(self, context, resource_type,
class_objs, event_type):
'''Alerts the l2 extension agent.
Notifies if a classification or a classification
group has been made.
'''
if (event_type == events.CREATED
and resource_type ==
n_class_obj.ClassificationGroup.obj_name()):
for c_obj in class_objs:
self.agent_api.register_classification_group(
c_obj.id, c_obj)
if (event_type == events.CREATED and resource_type
in self.class_type_list):
for c_obj in class_objs:
self.agent_api.register_classification(c_obj.id,
c_obj)

View File

@ -14,15 +14,17 @@
from oslo_log import log as logging from oslo_log import log as logging
from neutron_lib.db import api as db_api
from neutron.objects import base as base_obj from neutron.objects import base as base_obj
from neutron.objects import classification as n_class_obj
from neutron_classifier.common import constants as nc_consts
from neutron_classifier.common import exceptions from neutron_classifier.common import exceptions
from neutron_classifier.common import validators from neutron_classifier.common import validators
from neutron_classifier.db import classification as c_db from neutron_classifier.db import classification as c_db
from neutron_classifier.extensions import classification from neutron_classifier.extensions import classification
from neutron_classifier.objects import classification as class_obj
from neutron_classifier.objects import classification_type as type_obj from neutron_classifier.objects import classification_type as type_obj
from neutron_classifier.objects import classifications as class_group from neutron_classifier.services.classification import advertiser
from neutron_lib.db import api as db_api
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -33,7 +35,7 @@ class ClassificationPlugin(classification.NeutronClassificationPluginBase,
def __init__(self): def __init__(self):
super(ClassificationPlugin, self).__init__() super(ClassificationPlugin, self).__init__()
self.driver_manager = None self.driver_manager = advertiser.ClassificationAdvertiser()
def create_classification(self, context, classification): def create_classification(self, context, classification):
details = self.break_out_headers(classification) details = self.break_out_headers(classification)
@ -44,24 +46,29 @@ class ClassificationPlugin(classification.NeutronClassificationPluginBase,
if key not in validators.type_validators[c_type].keys(): if key not in validators.type_validators[c_type].keys():
raise exceptions.InvalidClassificationDefintion() raise exceptions.InvalidClassificationDefintion()
cl = class_group.CLASS_MAP[c_type](context, **details) cl = class_obj.CLASS_MAP[c_type](context, **details)
with db_api.CONTEXT_WRITER.using(context): with db_api.CONTEXT_WRITER.using(context):
cl.create() cl.create()
db_dict = self.merge_header(cl) db_dict = self.merge_header(cl)
db_dict['id'] = cl['id'] db_dict['id'] = cl['id']
self.driver_manager.call(nc_consts.CREATE_CLASS, context, cl)
return db_dict return db_dict
def delete_classification(self, context, classification_id): def delete_classification(self, context, classification_id):
cl = class_group.ClassificationBase.get_object(context, cl = n_class_obj.ClassificationBase.\
id=classification_id) get_object(context,
cl_class = class_group.CLASS_MAP[cl.c_type] id=classification_id)
cl_class = class_obj.CLASS_MAP[cl.c_type]
classification = cl_class.get_object(context, id=classification_id) classification = cl_class.get_object(context, id=classification_id)
validators.check_valid_classifications(context, validators.check_valid_classifications(context,
[classification_id]) [classification_id])
with db_api.CONTEXT_WRITER.using(context): with db_api.CONTEXT_WRITER.using(context):
classification.delete() classification.delete()
self.driver_manager.call(nc_consts.DELETE_CLASS, context,
classification)
def update_classification(self, context, classification_id, def update_classification(self, context, classification_id,
fields_to_update): fields_to_update):
@ -71,9 +78,10 @@ class ClassificationPlugin(classification.NeutronClassificationPluginBase,
for key in field_keys: for key in field_keys:
if key not in valid_keys: if key not in valid_keys:
raise exceptions.InvalidUpdateRequest() raise exceptions.InvalidUpdateRequest()
cl = class_group.ClassificationBase.get_object(context, cl = n_class_obj.ClassificationBase.\
id=classification_id) get_object(context,
cl_class = class_group.CLASS_MAP[cl.c_type] id=classification_id)
cl_class = class_obj.CLASS_MAP[cl.c_type]
with db_api.CONTEXT_WRITER.using(context): with db_api.CONTEXT_WRITER.using(context):
classification = cl_class.update_object( classification = cl_class.update_object(
context, fields_to_update, id=classification_id) context, fields_to_update, id=classification_id)
@ -82,9 +90,10 @@ class ClassificationPlugin(classification.NeutronClassificationPluginBase,
return db_dict return db_dict
def get_classification(self, context, classification_id, fields=None): def get_classification(self, context, classification_id, fields=None):
cl = class_group.ClassificationBase.get_object(context, cl = n_class_obj.ClassificationBase.\
id=classification_id) get_object(context,
cl_class = class_group.CLASS_MAP[cl.c_type] id=classification_id)
cl_class = class_obj.CLASS_MAP[cl.c_type]
classification = cl_class.get_object(context, id=classification_id) classification = cl_class.get_object(context, id=classification_id)
clas = self.merge_header(classification) clas = self.merge_header(classification)
@ -95,8 +104,8 @@ class ClassificationPlugin(classification.NeutronClassificationPluginBase,
page_reverse=False): page_reverse=False):
c_type = filters['c_type'][0] c_type = filters['c_type'][0]
pager = base_obj.Pager(sorts, limit, page_reverse, marker) pager = base_obj.Pager(sorts, limit, page_reverse, marker)
cl = class_group.CLASS_MAP[c_type].get_objects(context, cl = class_obj.CLASS_MAP[c_type].get_objects(context,
_pager=pager) _pager=pager)
db_dict = self.merge_headers(cl) db_dict = self.merge_headers(cl)
return db_dict return db_dict
@ -106,7 +115,7 @@ class ClassificationPlugin(classification.NeutronClassificationPluginBase,
ret_list = [] ret_list = []
if not filters: if not filters:
filters = {} filters = {}
for key in class_group.CLASS_MAP.keys(): for key in class_obj.CLASS_MAP.keys():
types = {} types = {}
obj = type_obj.ClassificationType.get_object(key) obj = type_obj.ClassificationType.get_object(key)
types['type'] = obj.type types['type'] = obj.type

View File

@ -14,6 +14,7 @@
import copy import copy
from neutron.db import classification as cs_db
from neutron.tests.unit import testlib_api from neutron.tests.unit import testlib_api
from neutron_classifier.db import models from neutron_classifier.db import models
from neutron_lib import context from neutron_lib import context
@ -46,9 +47,9 @@ class TestDatabaseModels(testlib_api.MySQLTestCaseMixin,
standard_group['id'] = uuidutils.generate_uuid() standard_group['id'] = uuidutils.generate_uuid()
standard_class['name'] = "Test Class " + str(n) standard_class['name'] = "Test Class " + str(n)
standard_class['id'] = uuidutils.generate_uuid() standard_class['id'] = uuidutils.generate_uuid()
self._create_db_model(ctx, models.ClassificationGroup, self._create_db_model(ctx, cs_db.ClassificationGroup,
**standard_group) **standard_group)
self._create_db_model(ctx, models.ClassificationBase, self._create_db_model(ctx, cs_db.ClassificationBase,
**standard_class) **standard_class)
self.cg_list.append(copy.copy(standard_group)) self.cg_list.append(copy.copy(standard_group))
self.c_list.append(copy.copy(standard_class)) self.c_list.append(copy.copy(standard_class))
@ -78,10 +79,10 @@ class TestDatabaseModels(testlib_api.MySQLTestCaseMixin,
'stored_cg_id': self.cg_list[2]['id']}] 'stored_cg_id': self.cg_list[2]['id']}]
for n in range(4): for n in range(4):
self._create_db_model(ctx, models.CGToClassificationMapping, self._create_db_model(ctx, cs_db.CGToClassificationMapping,
**self.cg_to_c_list[n]) **self.cg_to_c_list[n])
self._create_db_model(ctx, self._create_db_model(ctx,
models.CGToClassificationGroupMapping, cs_db.CGToClassificationGroupMapping,
**self.cg_to_cg_list[n]) **self.cg_to_cg_list[n])
def _create_db_model(self, ctx, model, **kwargs): def _create_db_model(self, ctx, model, **kwargs):

View File

@ -2,6 +2,6 @@
# The order of packages is significant, because pip processes them in the order # The order of packages is significant, because pip processes them in the order
# of appearance. Changing the order has an impact on the overall integration # of appearance. Changing the order has an impact on the overall integration
# process, which may cause wedges in the gate later. # process, which may cause wedges in the gate later.
psutil>=3.2.2 # BSD psutil>=5.6.3 # BSD
psycopg2 psycopg2
PyMySQL>=0.7.6 # MIT License PyMySQL>=0.9.3 # MIT License

View File

@ -16,13 +16,15 @@ from oslo_utils import uuidutils
from neutron_lib.db import api as db_api from neutron_lib.db import api as db_api
from neutron.objects import classification as n_class_obj
from neutron.tests.unit import testlib_api from neutron.tests.unit import testlib_api
from neutron_classifier.common import exceptions from neutron_classifier.common import exceptions
from neutron_classifier.common import validators from neutron_classifier.common import validators
from neutron_classifier.db.classification import\ from neutron_classifier.db.classification import\
TrafficClassificationGroupPlugin as cg_plugin TrafficClassificationGroupPlugin as cg_plugin
from neutron_classifier.objects import classifications from neutron_classifier.objects import classification as class_obj
from neutron_classifier.services.classification.plugin import\ from neutron_classifier.services.classification.plugin import\
ClassificationPlugin as c_plugin ClassificationPlugin as c_plugin
from neutron_classifier.tests import objects_base as obj_base from neutron_classifier.tests import objects_base as obj_base
@ -57,8 +59,8 @@ class ClassificationGroupApiTest(testlib_api.MySQLTestCaseMixin,
def test_create_classification_group(self): def test_create_classification_group(self):
with db_api.CONTEXT_WRITER.using(self.ctx): with db_api.CONTEXT_WRITER.using(self.ctx):
tcp_class = classifications.TCPClassification tcp_class = class_obj.TCPClassification
ipv4_class = classifications.IPV4Classification ipv4_class = class_obj.IPV4Classification
cg2 = self._create_test_cg('Test Group 1') cg2 = self._create_test_cg('Test Group 1')
tcp = self._create_test_classification('tcp', tcp_class) tcp = self._create_test_classification('tcp', tcp_class)
ipv4 = self._create_test_classification('ipv4', ipv4_class) ipv4 = self._create_test_classification('ipv4', ipv4_class)
@ -72,11 +74,11 @@ class ClassificationGroupApiTest(testlib_api.MySQLTestCaseMixin,
}} }}
cg1 = self.test_plugin.create_classification_group(self.ctx, cg1 = self.test_plugin.create_classification_group(self.ctx,
cg_dict) cg_dict)
fetch_cg1 = classifications.ClassificationGroup.get_object( fetch_cg1 = n_class_obj.ClassificationGroup.get_object(
self.ctx, id=cg1['id']) self.ctx, id=cg1['id'])
mapped_cgs = classifications._get_mapped_classification_groups( mapped_cgs = class_obj._get_mapped_classification_groups(
self.ctx, fetch_cg1) self.ctx, fetch_cg1)
mapped_cs = classifications._get_mapped_classifications( mapped_cs = class_obj._get_mapped_classifications(
self.ctx, fetch_cg1) self.ctx, fetch_cg1)
mapped_classification_groups = [cg.id for cg in mapped_cgs] mapped_classification_groups = [cg.id for cg in mapped_cgs]
mapped_classifications = [c.id for c in mapped_cs] mapped_classifications = [c.id for c in mapped_cs]
@ -96,7 +98,7 @@ class ClassificationGroupApiTest(testlib_api.MySQLTestCaseMixin,
self.test_plugin.update_classification_group( self.test_plugin.update_classification_group(
self.ctx, cg1.id, self.ctx, cg1.id,
{'classification_group': {'name': 'Test Group updated'}}) {'classification_group': {'name': 'Test Group updated'}})
fetch_cg1 = classifications.ClassificationGroup.get_object( fetch_cg1 = n_class_obj.ClassificationGroup.get_object(
self.ctx, id=cg1['id']) self.ctx, id=cg1['id'])
self.assertRaises( self.assertRaises(
exceptions.InvalidUpdateRequest, exceptions.InvalidUpdateRequest,
@ -110,7 +112,7 @@ class ClassificationGroupApiTest(testlib_api.MySQLTestCaseMixin,
with db_api.CONTEXT_WRITER.using(self.ctx): with db_api.CONTEXT_WRITER.using(self.ctx):
cg1 = self._create_test_cg('Test Group 0') cg1 = self._create_test_cg('Test Group 0')
self.test_plugin.delete_classification_group(self.ctx, cg1.id) self.test_plugin.delete_classification_group(self.ctx, cg1.id)
fetch_cg1 = classifications.ClassificationGroup.get_object( fetch_cg1 = n_class_obj.ClassificationGroup.get_object(
self.ctx, id=cg1['id']) self.ctx, id=cg1['id'])
self.assertIsNone(fetch_cg1) self.assertIsNone(fetch_cg1)
@ -123,7 +125,7 @@ class ClassificationApiTest(testlib_api.MySQLTestCaseMixin,
self.test_clas_plugin = c_plugin() self.test_clas_plugin = c_plugin()
def test_create_classification(self): def test_create_classification(self):
attrs = self.get_random_attrs(classifications.EthernetClassification) attrs = self.get_random_attrs(class_obj.EthernetClassification)
c_type = 'ethernet' c_type = 'ethernet'
attrs['c_type'] = c_type attrs['c_type'] = c_type
attrs['definition'] = {} attrs['definition'] = {}
@ -133,7 +135,7 @@ class ClassificationApiTest(testlib_api.MySQLTestCaseMixin,
with db_api.CONTEXT_WRITER.using(self.ctx): with db_api.CONTEXT_WRITER.using(self.ctx):
c1 = self.test_clas_plugin.create_classification(self.ctx, c1 = self.test_clas_plugin.create_classification(self.ctx,
c_attrs) c_attrs)
fetch_c1 = classifications.EthernetClassification.get_object( fetch_c1 = class_obj.EthernetClassification.get_object(
self.ctx, id=c1['id'] self.ctx, id=c1['id']
) )
c_attrs['classification']['definition']['src_port'] = 'xyz' c_attrs['classification']['definition']['src_port'] = 'xyz'
@ -147,16 +149,16 @@ class ClassificationApiTest(testlib_api.MySQLTestCaseMixin,
self.assertEqual(y, fetch_c1[x]) self.assertEqual(y, fetch_c1[x])
def test_delete_classification(self): def test_delete_classification(self):
tcp_class = classifications.TCPClassification tcp_class = class_obj.TCPClassification
with db_api.CONTEXT_WRITER.using(self.ctx): with db_api.CONTEXT_WRITER.using(self.ctx):
tcp = self._create_test_classification('tcp', tcp_class) tcp = self._create_test_classification('tcp', tcp_class)
self.test_clas_plugin.delete_classification(self.ctx, tcp.id) self.test_clas_plugin.delete_classification(self.ctx, tcp.id)
fetch_tcp = classifications.TCPClassification.get_object( fetch_tcp = class_obj.TCPClassification.get_object(
self.ctx, id=tcp.id) self.ctx, id=tcp.id)
self.assertIsNone(fetch_tcp) self.assertIsNone(fetch_tcp)
def test_get_classification(self): def test_get_classification(self):
ipv4_class = classifications.IPV4Classification ipv4_class = class_obj.IPV4Classification
with db_api.CONTEXT_WRITER.using(self.ctx): with db_api.CONTEXT_WRITER.using(self.ctx):
ipv4 = self._create_test_classification('ipv4', ipv4_class) ipv4 = self._create_test_classification('ipv4', ipv4_class)
fetch_ipv4 = self.test_clas_plugin.get_classification(self.ctx, fetch_ipv4 = self.test_clas_plugin.get_classification(self.ctx,
@ -166,9 +168,9 @@ class ClassificationApiTest(testlib_api.MySQLTestCaseMixin,
def test_get_classifications(self): def test_get_classifications(self):
with db_api.CONTEXT_WRITER.using(self.ctx): with db_api.CONTEXT_WRITER.using(self.ctx):
c1 = self._create_test_classification( c1 = self._create_test_classification(
'ipv6', classifications.IPV6Classification) 'ipv6', class_obj.IPV6Classification)
c2 = self._create_test_classification( c2 = self._create_test_classification(
'udp', classifications.UDPClassification) 'udp', class_obj.UDPClassification)
fetch_cs_udp = self.test_clas_plugin.get_classifications( fetch_cs_udp = self.test_clas_plugin.get_classifications(
self.ctx, filters={'c_type': ['udp']}) self.ctx, filters={'c_type': ['udp']})
fetch_cs_ipv6 = self.test_clas_plugin.get_classifications( fetch_cs_ipv6 = self.test_clas_plugin.get_classifications(
@ -180,11 +182,11 @@ class ClassificationApiTest(testlib_api.MySQLTestCaseMixin,
def test_update_classification(self): def test_update_classification(self):
c1 = self._create_test_classification( c1 = self._create_test_classification(
'ethernet', classifications.EthernetClassification) 'ethernet', class_obj.EthernetClassification)
updated_name = 'Test Updated Classification' updated_name = 'Test Updated Classification'
with db_api.CONTEXT_WRITER.using(self.ctx): with db_api.CONTEXT_WRITER.using(self.ctx):
self.test_clas_plugin.update_classification( self.test_clas_plugin.update_classification(
self.ctx, c1.id, {'classification': {'name': updated_name}}) self.ctx, c1.id, {'classification': {'name': updated_name}})
fetch_c1 = classifications.EthernetClassification.get_object( fetch_c1 = class_obj.EthernetClassification.get_object(
self.ctx, id=c1.id) self.ctx, id=c1.id)
self.assertEqual(fetch_c1.name, updated_name) self.assertEqual(fetch_c1.name, updated_name)

View File

@ -17,9 +17,10 @@ import oslo_versionedobjects
from neutron_lib import context from neutron_lib import context
from neutron.objects import classification as n_class_obj
from neutron.tests.unit.objects import test_base from neutron.tests.unit.objects import test_base
from neutron_classifier.objects import classifications from neutron_classifier.objects import classification as class_obj
from neutron_classifier.tests import tools from neutron_classifier.tests import tools
@ -27,8 +28,8 @@ class _CCFObjectsTestCommon(object):
# TODO(ndahiwade): this represents classifications containing Enum fields, # TODO(ndahiwade): this represents classifications containing Enum fields,
# will need to be reworked if more classifications are added here later. # will need to be reworked if more classifications are added here later.
_Enum_classifications = [classifications.IPV4Classification, _Enum_classifications = [class_obj.IPV4Classification,
classifications.IPV6Classification] class_obj.IPV6Classification]
_Enumfield = oslo_versionedobjects.fields.EnumField _Enumfield = oslo_versionedobjects.fields.EnumField
ctx = context.get_admin_context() ctx = context.get_admin_context()
@ -49,7 +50,7 @@ class _CCFObjectsTestCommon(object):
'project_id': uuidutils.generate_uuid(), 'project_id': uuidutils.generate_uuid(),
'shared': False, 'shared': False,
'operator': 'AND'} 'operator': 'AND'}
cg = classifications.ClassificationGroup(self.ctx, **attrs) cg = n_class_obj.ClassificationGroup(self.ctx, **attrs)
cg.create() cg.create()
return cg return cg
@ -65,15 +66,15 @@ class _CCFObjectsTestCommon(object):
def _create_test_cg_cg_mapping(self, cg1, cg2): def _create_test_cg_cg_mapping(self, cg1, cg2):
attrs = {'container_cg_id': cg1, attrs = {'container_cg_id': cg1,
'stored_cg_id': cg2} 'stored_cg_id': cg2}
cg_m_cg = classifications.CGToClassificationGroupMapping(self.ctx, cg_m_cg = n_class_obj.CGToClassificationGroupMapping(self.ctx,
**attrs) **attrs)
cg_m_cg.create() cg_m_cg.create()
return cg_m_cg return cg_m_cg
def _create_test_cg_c_mapping(self, cg, c): def _create_test_cg_c_mapping(self, cg, c):
attrs = {'container_cg_id': cg, attrs = {'container_cg_id': cg,
'stored_classification_id': c} 'stored_classification_id': c}
cg_m_c = classifications.CGToClassificationMapping(self.ctx, cg_m_c = n_class_obj.CGToClassificationMapping(self.ctx,
**attrs) **attrs)
cg_m_c.create() cg_m_c.create()
return cg_m_c return cg_m_c

View File

@ -14,9 +14,8 @@
import mock import mock
from neutron.objects import base as base_obj from neutron.objects import base as base_obj
from neutron.objects import classification as n_class_obj
from neutron_classifier.db import classification as cg_api from neutron_classifier.db import classification as cg_api
from neutron_classifier.objects import classifications
from neutron_classifier.tests import base from neutron_classifier.tests import base
from neutron_lib import context from neutron_lib import context
from oslo_utils import uuidutils from oslo_utils import uuidutils
@ -72,10 +71,10 @@ class TestClassificationGroupPlugin(base.BaseClassificationTestCase):
} }
return self.test_cg return self.test_cg
@mock.patch.object(classifications.CGToClassificationGroupMapping, @mock.patch.object(n_class_obj.CGToClassificationGroupMapping,
'create') 'create')
@mock.patch.object(classifications.CGToClassificationMapping, 'create') @mock.patch.object(n_class_obj.CGToClassificationMapping, 'create')
@mock.patch.object(classifications.ClassificationGroup, 'create') @mock.patch.object(n_class_obj.ClassificationGroup, 'create')
def test_create_classification_group(self, mock_cg_create, def test_create_classification_group(self, mock_cg_create,
mock_cg_c_mapping_create, mock_cg_c_mapping_create,
mock_cg_cg_mapping_create): mock_cg_cg_mapping_create):
@ -105,7 +104,7 @@ class TestClassificationGroupPlugin(base.BaseClassificationTestCase):
mock_manager.create_cg_cg.assert_called_once() mock_manager.create_cg_cg.assert_called_once()
self.assertEqual(mock_manager.create_cg_c.call_count, c_len) self.assertEqual(mock_manager.create_cg_c.call_count, c_len)
@mock.patch.object(classifications.ClassificationGroup, 'get_object') @mock.patch.object(n_class_obj.ClassificationGroup, 'get_object')
@mock.patch('neutron_classifier.common.validators.' @mock.patch('neutron_classifier.common.validators.'
'check_can_delete_classification_group') 'check_can_delete_classification_group')
def test_delete_classification_group(self, mock_valid_delete, def test_delete_classification_group(self, mock_valid_delete,
@ -134,11 +133,11 @@ class TestClassificationGroupPlugin(base.BaseClassificationTestCase):
mock_manager.mock_calls.index(mock_cg_get_call) < mock_manager.mock_calls.index(mock_cg_get_call) <
mock_manager.mock_calls.index(mock_cg_delete_call)) mock_manager.mock_calls.index(mock_cg_delete_call))
@mock.patch('neutron_classifier.objects.classifications.' @mock.patch('neutron_classifier.objects.classification.'
'_get_mapped_classification_groups') '_get_mapped_classification_groups')
@mock.patch('neutron_classifier.objects.classifications.' @mock.patch('neutron_classifier.objects.classification.'
'_get_mapped_classifications') '_get_mapped_classifications')
@mock.patch.object(classifications.ClassificationGroup, 'get_object') @mock.patch.object(n_class_obj.ClassificationGroup, 'get_object')
@mock.patch('neutron_classifier.db.classification.' @mock.patch('neutron_classifier.db.classification.'
'TrafficClassificationGroupPlugin._make_db_dicts') 'TrafficClassificationGroupPlugin._make_db_dicts')
def test_get_classification_group(self, mock_db_dicts, mock_cg_get, def test_get_classification_group(self, mock_db_dicts, mock_cg_get,
@ -177,7 +176,7 @@ class TestClassificationGroupPlugin(base.BaseClassificationTestCase):
mock_mapped_cgs.assert_called_once() mock_mapped_cgs.assert_called_once()
@mock.patch.object(base_obj, 'Pager') @mock.patch.object(base_obj, 'Pager')
@mock.patch.object(classifications.ClassificationGroup, 'get_objects') @mock.patch.object(n_class_obj.ClassificationGroup, 'get_objects')
@mock.patch.object(cg_api.TrafficClassificationGroupPlugin, @mock.patch.object(cg_api.TrafficClassificationGroupPlugin,
'_make_db_dicts') '_make_db_dicts')
def test_get_classification_groups(self, mock_db_dicts, mock_cgs_get, def test_get_classification_groups(self, mock_db_dicts, mock_cgs_get,
@ -193,8 +192,8 @@ class TestClassificationGroupPlugin(base.BaseClassificationTestCase):
test_cg1 = test_cg1['classification_group'] test_cg1 = test_cg1['classification_group']
test_cg2 = test_cg2['classification_group'] test_cg2 = test_cg2['classification_group']
cg1 = classifications.ClassificationGroup(self.ctxt, **test_cg1) cg1 = n_class_obj.ClassificationGroup(self.ctxt, **test_cg1)
cg2 = classifications.ClassificationGroup(self.ctxt, **test_cg2) cg2 = n_class_obj.ClassificationGroup(self.ctxt, **test_cg2)
cg_list = [self.cg_plugin._make_db_dict(cg) for cg in [cg1, cg2]] cg_list = [self.cg_plugin._make_db_dict(cg) for cg in [cg1, cg2]]
mock_manager.get_cgs.return_value = cg_list mock_manager.get_cgs.return_value = cg_list
@ -206,7 +205,7 @@ class TestClassificationGroupPlugin(base.BaseClassificationTestCase):
mock_manager.db_dicts.assert_called_once() mock_manager.db_dicts.assert_called_once()
self.assertEqual(len(mock_manager.mock_calls), 3) self.assertEqual(len(mock_manager.mock_calls), 3)
@mock.patch.object(classifications.ClassificationGroup, 'update_object') @mock.patch.object(n_class_obj.ClassificationGroup, 'update_object')
def test_update_classification_group(self, mock_cg_update): def test_update_classification_group(self, mock_cg_update):
mock_manager = mock.Mock() mock_manager = mock.Mock()
mock_manager.attach_mock(mock_cg_update, 'cg_update') mock_manager.attach_mock(mock_cg_update, 'cg_update')
@ -215,7 +214,7 @@ class TestClassificationGroupPlugin(base.BaseClassificationTestCase):
test_cg = self._generate_test_classification_group('Test Group') test_cg = self._generate_test_classification_group('Test Group')
test_cg = test_cg['classification_group'] test_cg = test_cg['classification_group']
cg = classifications.ClassificationGroup(self.ctxt, **test_cg) cg = n_class_obj.ClassificationGroup(self.ctxt, **test_cg)
updated_fields = {'classification_group': updated_fields = {'classification_group':
{'name': 'Test Group Updated', {'name': 'Test Group Updated',

View File

@ -22,7 +22,6 @@ from oslo_versionedobjects import base as obj_base
from oslo_versionedobjects import fixture from oslo_versionedobjects import fixture
from neutron import objects as n_obj from neutron import objects as n_obj
from neutron_classifier import objects from neutron_classifier import objects
from neutron_classifier.tests import base as test_base from neutron_classifier.tests import base as test_base
@ -33,7 +32,6 @@ from neutron_classifier.tests import base as test_base
# This list also includes VersionedObjects from Neutron that are registered # This list also includes VersionedObjects from Neutron that are registered
# through dependencies. # through dependencies.
object_data = { object_data = {
'ClassificationGroup': '1.0-e621ff663f76bb494072872222f5fe72',
'CGToClassificationGroupMapping': '1.0-8ebed0ef1035bcc4b307da1bbdc6be64', 'CGToClassificationGroupMapping': '1.0-8ebed0ef1035bcc4b307da1bbdc6be64',
'CGToClassificationMapping': '1.0-fe5942adbe82301a38b67bdce484efb1', 'CGToClassificationMapping': '1.0-fe5942adbe82301a38b67bdce484efb1',
'EthernetClassification': '1.0-267f03162a6e011197b663ee34e6cb0b', 'EthernetClassification': '1.0-267f03162a6e011197b663ee34e6cb0b',

View File

@ -14,7 +14,8 @@
import oslo_versionedobjects import oslo_versionedobjects
from neutron_classifier.objects import classifications from neutron.objects import classification as n_class_obj
from neutron_classifier.objects import classification as class_obj
from neutron_classifier.tests import objects_base as obj_base from neutron_classifier.tests import objects_base as obj_base
from neutron_classifier.tests import tools from neutron_classifier.tests import tools
@ -33,18 +34,18 @@ class ClassificationGroupTest(test_base.BaseDbObjectTestCase,
# we are adding it here for our use rather than adding in neutron. # we are adding it here for our use rather than adding in neutron.
test_base.FIELD_TYPE_VALUE_GENERATOR_MAP[ test_base.FIELD_TYPE_VALUE_GENERATOR_MAP[
oslo_versionedobjects.fields.EnumField] = tools.get_random_operator oslo_versionedobjects.fields.EnumField] = tools.get_random_operator
_test_class = classifications.ClassificationGroup _test_class = n_class_obj.ClassificationGroup
def test_get_object(self): def test_get_object(self):
cg = self._create_test_cg('Test Group 0') cg = self._create_test_cg('Test Group 0')
fetch_cg = classifications.ClassificationGroup.get_object( fetch_cg = n_class_obj.ClassificationGroup.get_object(
self.ctx, id=cg.id) self.ctx, id=cg.id)
self.assertEqual(cg, fetch_cg) self.assertEqual(cg, fetch_cg)
def test_get_objects(self): def test_get_objects(self):
cg1 = self._create_test_cg('Test Group 1') cg1 = self._create_test_cg('Test Group 1')
cg2 = self._create_test_cg('Test Group 2') cg2 = self._create_test_cg('Test Group 2')
cgs = classifications.ClassificationGroup.get_objects(self.ctx) cgs = n_class_obj.ClassificationGroup.get_objects(self.ctx)
self.assertIn(cg1, cgs) self.assertIn(cg1, cgs)
self.assertIn(cg2, cgs) self.assertIn(cg2, cgs)
@ -55,7 +56,7 @@ class ClassificationGroupTest(test_base.BaseDbObjectTestCase,
class UDPClassificationTest(testlib_api.SqlTestCase, class UDPClassificationTest(testlib_api.SqlTestCase,
obj_base._CCFObjectsTestCommon): obj_base._CCFObjectsTestCommon):
test_class = classifications.UDPClassification test_class = class_obj.UDPClassification
def test_get_object(self): def test_get_object(self):
udp = self._create_test_classification('udp', self.test_class) udp = self._create_test_classification('udp', self.test_class)
@ -73,7 +74,7 @@ class UDPClassificationTest(testlib_api.SqlTestCase,
class IPV4ClassificationTest(testlib_api.SqlTestCase, class IPV4ClassificationTest(testlib_api.SqlTestCase,
obj_base._CCFObjectsTestCommon): obj_base._CCFObjectsTestCommon):
test_class = classifications.IPV4Classification test_class = class_obj.IPV4Classification
def test_get_object(self): def test_get_object(self):
ipv4 = self._create_test_classification('ipv4', self.test_class) ipv4 = self._create_test_classification('ipv4', self.test_class)
@ -91,7 +92,7 @@ class IPV4ClassificationTest(testlib_api.SqlTestCase,
class IPV6ClassificationTest(testlib_api.SqlTestCase, class IPV6ClassificationTest(testlib_api.SqlTestCase,
obj_base._CCFObjectsTestCommon): obj_base._CCFObjectsTestCommon):
test_class = classifications.IPV6Classification test_class = class_obj.IPV6Classification
def test_get_object(self): def test_get_object(self):
ipv6 = self._create_test_classification('ipv6', self.test_class) ipv6 = self._create_test_classification('ipv6', self.test_class)
@ -109,7 +110,7 @@ class IPV6ClassificationTest(testlib_api.SqlTestCase,
class TCPClassificationTest(testlib_api.SqlTestCase, class TCPClassificationTest(testlib_api.SqlTestCase,
obj_base._CCFObjectsTestCommon): obj_base._CCFObjectsTestCommon):
test_class = classifications.TCPClassification test_class = class_obj.TCPClassification
def test_get_object(self): def test_get_object(self):
tcp = self._create_test_classification('tcp', self.test_class) tcp = self._create_test_classification('tcp', self.test_class)
@ -127,7 +128,7 @@ class TCPClassificationTest(testlib_api.SqlTestCase,
class EthernetClassificationTest(testlib_api.SqlTestCase, class EthernetClassificationTest(testlib_api.SqlTestCase,
obj_base._CCFObjectsTestCommon): obj_base._CCFObjectsTestCommon):
test_class = classifications.EthernetClassification test_class = class_obj.EthernetClassification
def test_get_object(self): def test_get_object(self):
ethernet = self._create_test_classification('ethernet', ethernet = self._create_test_classification('ethernet',
@ -153,11 +154,11 @@ class CGToClassificationGroupMappingTest(testlib_api.SqlTestCase,
cg1 = self._create_test_cg('Test Group 0') cg1 = self._create_test_cg('Test Group 0')
cg2 = self._create_test_cg('Test Group 1') cg2 = self._create_test_cg('Test Group 1')
cg_m_cg = self._create_test_cg_cg_mapping(cg1.id, cg2.id) cg_m_cg = self._create_test_cg_cg_mapping(cg1.id, cg2.id)
fetch_cg = classifications.ClassificationGroup.get_object( fetch_cg = n_class_obj.ClassificationGroup.get_object(
self.ctx, id=cg1.id) self.ctx, id=cg1.id)
mapped_cg = classifications._get_mapped_classification_groups( mapped_cg = class_obj._get_mapped_classification_groups(
self.ctx, fetch_cg) self.ctx, fetch_cg)
fetch_cg_m_cg = classifications.CGToClassificationGroupMapping.\ fetch_cg_m_cg = n_class_obj.CGToClassificationGroupMapping.\
get_object(self.ctx, id=cg_m_cg.container_cg_id) get_object(self.ctx, id=cg_m_cg.container_cg_id)
self.assertEqual(mapped_cg[0], cg2) self.assertEqual(mapped_cg[0], cg2)
self.assertEqual(cg_m_cg, fetch_cg_m_cg) self.assertEqual(cg_m_cg, fetch_cg_m_cg)
@ -171,9 +172,9 @@ class CGToClassificationGroupMappingTest(testlib_api.SqlTestCase,
cgs = [cg2, cg3, cg4] cgs = [cg2, cg3, cg4]
for cg in cgs: for cg in cgs:
self._create_test_cg_cg_mapping(cg1.id, cg.id) self._create_test_cg_cg_mapping(cg1.id, cg.id)
fetch_cg1 = classifications.ClassificationGroup.get_object( fetch_cg1 = n_class_obj.ClassificationGroup.get_object(
self.ctx, id=cg1.id) self.ctx, id=cg1.id)
mapped_cgs = classifications._get_mapped_classification_groups( mapped_cgs = class_obj._get_mapped_classification_groups(
self.ctx, fetch_cg1) self.ctx, fetch_cg1)
for cg in cgs: for cg in cgs:
self.assertIn(cg, mapped_cgs) self.assertIn(cg, mapped_cgs)
@ -188,15 +189,15 @@ class CGToClassificationMappingTest(testlib_api.SqlTestCase,
with db_api.CONTEXT_WRITER.using(self.ctx): with db_api.CONTEXT_WRITER.using(self.ctx):
cg = self._create_test_cg('Test Group') cg = self._create_test_cg('Test Group')
cl_ = self._create_test_classification( cl_ = self._create_test_classification(
'udp', classifications.UDPClassification) 'udp', class_obj.UDPClassification)
cg_m_c = self._create_test_cg_c_mapping(cg.id, cl_.id) cg_m_c = self._create_test_cg_c_mapping(cg.id, cl_.id)
fetch_c = classifications.UDPClassification.get_object( fetch_c = class_obj.UDPClassification.get_object(
self.ctx, id=cl_.id) self.ctx, id=cl_.id)
fetch_cg = classifications.ClassificationGroup.get_object( fetch_cg = n_class_obj.ClassificationGroup.get_object(
self.ctx, id=cg.id) self.ctx, id=cg.id)
mapped_cs = classifications._get_mapped_classifications( mapped_cs = class_obj._get_mapped_classifications(
self.ctx, fetch_cg) self.ctx, fetch_cg)
fetch_cg_m_c = classifications.CGToClassificationMapping. \ fetch_cg_m_c = n_class_obj.CGToClassificationMapping. \
get_object(self.ctx, id=cg_m_c.container_cg_id) get_object(self.ctx, id=cg_m_c.container_cg_id)
self.assertIn(fetch_c, mapped_cs) self.assertIn(fetch_c, mapped_cs)
self.assertEqual(cg_m_c, fetch_cg_m_c) self.assertEqual(cg_m_c, fetch_cg_m_c)
@ -205,17 +206,17 @@ class CGToClassificationMappingTest(testlib_api.SqlTestCase,
with db_api.CONTEXT_WRITER.using(self.ctx): with db_api.CONTEXT_WRITER.using(self.ctx):
cg = self._create_test_cg('Test Group') cg = self._create_test_cg('Test Group')
c1 = self._create_test_classification( c1 = self._create_test_classification(
'tcp', classifications.TCPClassification) 'tcp', class_obj.TCPClassification)
c2 = self._create_test_classification( c2 = self._create_test_classification(
'udp', classifications.UDPClassification) 'udp', class_obj.UDPClassification)
c3 = self._create_test_classification( c3 = self._create_test_classification(
'ethernet', classifications.EthernetClassification) 'ethernet', class_obj.EthernetClassification)
cs = [c1, c2, c3] cs = [c1, c2, c3]
for c in cs: for c in cs:
self._create_test_cg_c_mapping(cg.id, c.id) self._create_test_cg_c_mapping(cg.id, c.id)
fetch_cg = classifications.ClassificationGroup.get_object( fetch_cg = n_class_obj.ClassificationGroup.get_object(
self.ctx, id=cg.id) self.ctx, id=cg.id)
mapped_cs = classifications._get_mapped_classifications( mapped_cs = class_obj._get_mapped_classifications(
self.ctx, fetch_cg) self.ctx, fetch_cg)
for c in cs: for c in cs:
self.assertIn(c, mapped_cs) self.assertIn(c, mapped_cs)

View File

@ -0,0 +1,150 @@
# Copyright 2019 Intel Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from neutron.api.rpc import handlers
from neutron.tests.unit import testlib_api
from neutron_classifier.services.classification import advertiser
from oslo_utils import uuidutils
class TestAdvertiser(testlib_api.SqlTestCase):
def setUp(self):
super(TestAdvertiser, self).setUp()
self.mock_context = mock.Mock()
self.mock_id = (uuidutils.generate_uuid())
mock.patch.object(advertiser, 'resources_rpc').start()
mock.patch.object(advertiser, 'n_rpc').start()
mock.patch.object(advertiser, 'n_class_obj').start()
mock.patch.object(advertiser, 'db_api').start()
self.advertiser = advertiser.ClassificationAdvertiser()
mock.patch('neutron.objects.db.api.get_object').start()
mock.patch('neutron_lib.db.api.CONTEXT_READER.using').start()
@mock.patch.object(advertiser, 'class_obj')
@mock.patch.object(advertiser, 'n_class_obj')
def test_get_classification(self, mock_n_class_obj,
mock_cs_group):
mock_obj = mock.Mock()
mock_cls = mock.Mock()
mock_cls.get_object.return_value = mock_obj
mock_base = mock.Mock()
mock_cs_group.CLASS_MAP = mock.Mock()
mock_n_class_obj.ClassificationBase = mock.Mock(
return_value=mock_base)
mock_cs_group.CLASS_MAP.__getitem__ = mock.\
Mock(return_value=mock_cls)
return_obj = self.advertiser.\
_get_classification('', self.mock_id,
context=self.mock_context)
self.assertEqual(return_obj, mock_obj)
@mock.patch.object(advertiser, 'n_class_obj')
def test_get_classification_group(self, mock_n_class_obj):
test_cg = mock.Mock()
mock_n_class_obj.ClassificationGroup.get_object.\
return_value = test_cg
return_obj = self.advertiser.\
_get_classification_group('', self.mock_id,
context=self.mock_context)
self.assertEqual(return_obj, test_cg)
@mock.patch.object(advertiser.models, '_read_classifications')
@mock.patch.object(advertiser.models, '_read_classification_groups')
@mock.patch.object(advertiser, 'db_api')
def test_get_classification_group_mapping(self, mock_db_api,
mock_read_cls_grp,
mock_read_cls):
mock_db_api.start()
mock_read_cls_grp.start()
mock_read_cls.start()
ret_obj = None
mock_cls = mock.Mock()
mock_cls_grp = mock.Mock()
mock_cls.id = uuidutils.generate_uuid()
mock_cls_grp.id = uuidutils.generate_uuid()
print("mock_cls_grp.id", mock_cls_grp.id)
print("mock_cls.id", mock_cls.id)
mock_read_cls_grp.return_value = [mock_cls_grp]
mock_read_cls.return_value = [mock_cls]
mock_mapped_cs = [mock_cs.id for mock_cs in [mock_cls]]
mock_mapped_cgs = [mock_cgs.id for mock_cgs in [mock_cls_grp]]
print('mock_mapped_cs', mock_mapped_cs)
print('mock_mapped_cgs', mock_mapped_cgs)
with mock_db_api.using(self.mock_context):
ret_obj = self.advertiser._get_classification_group_mapping(
self.mock_context, self.mock_id)
self.assertEqual(
{'classifications': mock_mapped_cs, 'classification_groups':
mock_mapped_cgs}, ret_obj)
mock_read_cls_grp.assert_called_once_with(self.mock_context,
self.mock_id)
mock_read_cls.assert_called_once_with(self.mock_context,
self.mock_id)
@mock.patch.object(handlers, 'resources_rpc')
@mock.patch.object(advertiser, 'nc_consts')
@mock.patch.object(advertiser, 'rpc_events')
def test_call_creates_classes(self, mock_rpc_events, mock_nc_consts,
mock_resources_rpc):
mock_resources_rpc.start()
mock_nc_consts.start()
mock_rpc_events.start()
self.mock_classification = mock.Mock()
self.advertiser.push_api = mock.Mock()
self.advertiser.push_api.push = mock.Mock()
self.advertiser.call(mock_nc_consts.CREATE_CLASS,
context=self.mock_context,
classification=self.mock_classification)
self.advertiser.push_api.push.\
assert_called_with(self.mock_context, [self.mock_classification],
mock_rpc_events.CREATED)
@mock.patch.object(handlers, 'resources_rpc')
@mock.patch.object(advertiser, 'nc_consts')
@mock.patch.object(advertiser, 'rpc_events')
def test_call_deletes_classes(self, mock_rpc_events, mock_nc_consts,
mock_resources_rpc):
mock_resources_rpc.start()
mock_nc_consts.start()
mock_rpc_events.start()
self.mock_classification = mock.Mock()
self.advertiser.push_api = mock.Mock()
self.advertiser.push_api.push = mock.Mock()
self.advertiser.call(mock_nc_consts.DELETE_CLASS,
context=self.mock_context,
classification=self.mock_classification)
self.advertiser.push_api.push.\
assert_called_with(self.mock_context,
[self.mock_classification],
mock_rpc_events.DELETED)
def tearDown(self):
super(TestAdvertiser, self).tearDown()

View File

@ -0,0 +1,97 @@
# Copyright 2019 Intel Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import unittest
import mock
from neutron.api.rpc.callbacks import events
from neutron.objects import classification as n_class_obj
from neutron_classifier.objects import classification as class_obj
from neutron_classifier.services.classification import extension
from oslo_utils import uuidutils
class TestClassifierExtension(unittest.TestCase):
def setUp(self):
super(TestClassifierExtension, self).setUp()
self.mock_context = mock.Mock()
mock.patch.object(extension, 'resources').start()
self.mock_id = uuidutils.generate_uuid()
self.extension = extension.NeutronClassifierExtension()
self.extension.agent_api = mock.Mock()
mock_rtvt = mock.patch('neutron.api.rpc.handlers.resources_rpc'
'.resource_type_versioned_topic')
mock_r = mock.patch('neutron.api.rpc.callbacks'
'.consumer.registry.register')
mock_rtvt.start()
mock_r.start()
def test_register_rpc_consumers(self):
mock_connection = mock.Mock()
mock_consumer = mock.MagicMock()
mock_connection.create_consumer = mock_consumer
test_supported_resource_types = [
n_class_obj.ClassificationGroup.obj_name(),
n_class_obj.ClassificationBase.obj_name(),
class_obj.EthernetClassification.obj_name(),
class_obj.IPV4Classification.obj_name(),
class_obj.IPV6Classification.obj_name(),
class_obj.UDPClassification.obj_name(),
class_obj.TCPClassification.obj_name()
]
self.extension._register_rpc_consumers(mock_connection)
self.assertEqual(mock_consumer.call_count,
len(test_supported_resource_types))
def test_handle_notification_ignores_events(self):
self.extension.agent_api.register_classification = mock.Mock()
for event_type in set(events.VALID) - {events.CREATED}:
self.extension.handle_notification(mock.Mock(), '',
object(), event_type)
self.assertFalse(self.extension.agent_api.
register_classification.called)
def test_handle_notification_passes_events_classification(self):
self.extension.agent_api.register_classification = mock.Mock()
class_obj = mock.Mock()
self.extension.handle_notification(mock.Mock(), 'IPV4Classification',
[class_obj], events.CREATED)
self.extension.agent_api.register_classification. \
assert_called_once()
self.assertFalse(self.extension.agent_api.
register_classification_group.called)
def test_handle_notification_passes_events_classification_group(self):
self.extension.agent_api.register_classification_group = mock.Mock()
class_obj = mock.Mock()
self.extension.handle_notification(mock.Mock(), 'ClassificationGroup',
[class_obj], events.CREATED)
self.extension.agent_api.register_classification_group. \
assert_called_once()
self.assertFalse(self.extension.agent_api.
register_classification.called)
def tearDown(self):
super(TestClassifierExtension, self).tearDown()
pass

View File

@ -14,7 +14,8 @@
import mock import mock
from neutron.objects import base as base_obj from neutron.objects import base as base_obj
from neutron_classifier.objects import classifications as class_group from neutron.objects import classification as n_class_obj
from neutron_classifier.objects import classification as class_obj
from neutron_classifier.services.classification import plugin from neutron_classifier.services.classification import plugin
from neutron_classifier.tests import base from neutron_classifier.tests import base
from neutron_lib import context from neutron_lib import context
@ -31,6 +32,8 @@ class TestPlugin(base.BaseClassificationTestCase):
mock.patch('neutron.objects.db.api.update_object').start() mock.patch('neutron.objects.db.api.update_object').start()
mock.patch('neutron.objects.db.api.delete_object').start() mock.patch('neutron.objects.db.api.delete_object').start()
mock.patch('neutron.objects.db.api.get_object').start() mock.patch('neutron.objects.db.api.get_object').start()
mock.patch('neutron_classifier.services.classification.advertiser'
'.ClassificationAdvertiser').start()
self.cl_plugin = plugin.ClassificationPlugin() self.cl_plugin = plugin.ClassificationPlugin()
@ -38,7 +41,7 @@ class TestPlugin(base.BaseClassificationTestCase):
mock.patch.object(self.ctxt.session, 'refresh').start() mock.patch.object(self.ctxt.session, 'refresh').start()
mock.patch.object(self.ctxt.session, 'expunge').start() mock.patch.object(self.ctxt.session, 'expunge').start()
mock.patch('neutron_classifier.objects.classifications').start() mock.patch('neutron_classifier.objects.classification').start()
self._generate_test_classifications() self._generate_test_classifications()
@ -106,8 +109,8 @@ class TestPlugin(base.BaseClassificationTestCase):
self.assertEqual(self.test_classification['classification'], self.assertEqual(self.test_classification['classification'],
cl) cl)
@mock.patch.object(class_group.EthernetClassification, 'create') @mock.patch.object(class_obj.EthernetClassification, 'create')
@mock.patch.object(class_group.EthernetClassification, 'id', @mock.patch.object(class_obj.EthernetClassification, 'id',
return_value=uuidutils.generate_uuid()) return_value=uuidutils.generate_uuid())
def test_create_classification(self, mock_ethernet_id, def test_create_classification(self, mock_ethernet_id,
mock_ethernet_create): mock_ethernet_create):
@ -123,15 +126,15 @@ class TestPlugin(base.BaseClassificationTestCase):
self.ctxt, self.test_classification) self.ctxt, self.test_classification)
expected_val = self.test_classification['classification'] expected_val = self.test_classification['classification']
expected_val['id'] = class_group.EthernetClassification.id expected_val['id'] = class_obj.EthernetClassification.id
self.assertEqual(expected_val, val) self.assertEqual(expected_val, val)
mock_manager.create.assert_called_once() mock_manager.create.assert_called_once()
@mock.patch.object(plugin.ClassificationPlugin, 'merge_header') @mock.patch.object(plugin.ClassificationPlugin, 'merge_header')
@mock.patch.object(class_group.ClassificationBase, 'get_object') @mock.patch.object(n_class_obj.ClassificationBase, 'get_object')
@mock.patch.object(class_group.EthernetClassification, 'update_object') @mock.patch.object(class_obj.EthernetClassification, 'update_object')
@mock.patch.object(class_group.EthernetClassification, 'id', @mock.patch.object(class_obj.EthernetClassification, 'id',
return_value=uuidutils.generate_uuid()) return_value=uuidutils.generate_uuid())
def test_update_classification(self, mock_id, mock_ethernet_update, def test_update_classification(self, mock_id, mock_ethernet_update,
mock_class_get, mock_merge): mock_class_get, mock_merge):
@ -143,7 +146,7 @@ class TestPlugin(base.BaseClassificationTestCase):
mock_manager.reset_mock() mock_manager.reset_mock()
mock_manager.start() mock_manager.start()
class_obj = class_group.EthernetClassification( c_obj = class_obj.EthernetClassification(
self.ctxt, **self.test_classification_broken_headers) self.ctxt, **self.test_classification_broken_headers)
ethernet_classification_update = {'classification': { ethernet_classification_update = {'classification': {
@ -152,29 +155,29 @@ class TestPlugin(base.BaseClassificationTestCase):
mock_manager.get_classification().c_type = 'ethernet' mock_manager.get_classification().c_type = 'ethernet'
self.cl_plugin.update_classification( self.cl_plugin.update_classification(
self.ctxt, class_obj.id, self.ctxt, c_obj.id,
ethernet_classification_update) ethernet_classification_update)
classification_update_mock_call = mock.call.update( classification_update_mock_call = mock.call.update(
self.ctxt, self.ctxt,
{'description': 'Test Ethernet Classification Version 2', {'description': 'Test Ethernet Classification Version 2',
'name': 'test_ethernet_classification Version 2'}, 'name': 'test_ethernet_classification Version 2'},
id=class_obj.id) id=c_obj.id)
self.assertIn(classification_update_mock_call, mock_manager.mock_calls) self.assertIn(classification_update_mock_call, mock_manager.mock_calls)
self.assertEqual(mock_manager.get_classification.call_count, 2) self.assertEqual(mock_manager.get_classification.call_count, 2)
@mock.patch.object(class_group.ClassificationBase, 'get_object') @mock.patch.object(n_class_obj.ClassificationBase, 'get_object')
@mock.patch.object(class_group.EthernetClassification, 'get_object') @mock.patch.object(class_obj.EthernetClassification, 'get_object')
def test_delete_classification(self, mock_ethernet_get, mock_base_get): def test_delete_classification(self, mock_ethernet_get, mock_base_get):
mock_manager = mock.Mock() mock_manager = mock.Mock()
mock_manager.attach_mock(mock_base_get, 'get_object') mock_manager.attach_mock(mock_base_get, 'get_object')
mock_manager.attach_mock(mock_ethernet_get, 'get_object') mock_manager.attach_mock(mock_ethernet_get, 'get_object')
eth_class_obj = class_group.EthernetClassification( eth_class_obj = class_obj.EthernetClassification(
self.ctxt, **self.test_classification_broken_headers) self.ctxt, **self.test_classification_broken_headers)
eth_class_obj.delete = mock.Mock() eth_class_obj.delete = mock.Mock()
base_class_obj = class_group.ClassificationBase( base_class_obj = n_class_obj.ClassificationBase(
self.ctxt, **self.test_classification_broken_headers) self.ctxt, **self.test_classification_broken_headers)
mock_base_get.return_value = base_class_obj mock_base_get.return_value = base_class_obj
@ -193,8 +196,8 @@ class TestPlugin(base.BaseClassificationTestCase):
mock_manager.mock_calls) mock_manager.mock_calls)
self.assertTrue(eth_class_obj.delete.assert_called_once) self.assertTrue(eth_class_obj.delete.assert_called_once)
@mock.patch.object(class_group.ClassificationBase, 'get_object') @mock.patch.object(n_class_obj.ClassificationBase, 'get_object')
@mock.patch.object(class_group.EthernetClassification, 'get_object') @mock.patch.object(class_obj.EthernetClassification, 'get_object')
def test_get_classification(self, mock_ethernet_get, def test_get_classification(self, mock_ethernet_get,
mock_base_get): mock_base_get):
mock_manager = mock.Mock() mock_manager = mock.Mock()
@ -206,9 +209,9 @@ class TestPlugin(base.BaseClassificationTestCase):
definition = eth_classification.pop('definition') definition = eth_classification.pop('definition')
base_class_obj = class_group.ClassificationBase( base_class_obj = n_class_obj.ClassificationBase(
self.ctxt, **eth_classification) self.ctxt, **eth_classification)
eth_class_obj = class_group.EthernetClassification( eth_class_obj = class_obj.EthernetClassification(
self.ctxt, **self.test_classification_broken_headers) self.ctxt, **self.test_classification_broken_headers)
mock_base_get.return_value = base_class_obj mock_base_get.return_value = base_class_obj
@ -227,8 +230,8 @@ class TestPlugin(base.BaseClassificationTestCase):
mock_manager.mock_calls) mock_manager.mock_calls)
self.assertTrue(eth_classification, value) self.assertTrue(eth_classification, value)
@mock.patch.object(class_group.ClassificationBase, 'get_objects') @mock.patch.object(n_class_obj.ClassificationBase, 'get_objects')
@mock.patch.object(class_group.EthernetClassification, 'get_objects') @mock.patch.object(class_obj.EthernetClassification, 'get_objects')
@mock.patch.object(base_obj, 'Pager') @mock.patch.object(base_obj, 'Pager')
def test_get_classifications(self, mock_pager, mock_ethernet_get, def test_get_classifications(self, mock_pager, mock_ethernet_get,
mock_base_get): mock_base_get):
@ -242,13 +245,13 @@ class TestPlugin(base.BaseClassificationTestCase):
definition = eth_cl_1.pop('definition') definition = eth_cl_1.pop('definition')
definition_2 = eth_cl_2.pop('definition') definition_2 = eth_cl_2.pop('definition')
base_class_obj_1 = class_group.ClassificationBase( base_class_obj_1 = n_class_obj.ClassificationBase(
self.ctxt, **eth_cl_1) self.ctxt, **eth_cl_1)
base_class_obj_2 = class_group.ClassificationBase( base_class_obj_2 = n_class_obj.ClassificationBase(
self.ctxt, **eth_cl_2) self.ctxt, **eth_cl_2)
eth_class_obj_1 = class_group.EthernetClassification( eth_class_obj_1 = class_obj.EthernetClassification(
self.ctxt, **self.test_classification_broken_headers) self.ctxt, **self.test_classification_broken_headers)
eth_class_obj_2 = class_group.EthernetClassification( eth_class_obj_2 = class_obj.EthernetClassification(
self.ctxt, **self.test_classification_2_broken_headers) self.ctxt, **self.test_classification_2_broken_headers)
base_list = [base_class_obj_1, base_class_obj_2] base_list = [base_class_obj_1, base_class_obj_2]

View File

@ -59,6 +59,8 @@ openstack.neutronclient.v2 =
network classification group update = neutron_classifier.cli.openstack_cli.classification_group:UpdateClassificationGroup network classification group update = neutron_classifier.cli.openstack_cli.classification_group:UpdateClassificationGroup
neutron.db.alembic_migrations = neutron.db.alembic_migrations =
neutron-classifier = neutron_classifier.db.migration:alembic_migrations neutron-classifier = neutron_classifier.db.migration:alembic_migrations
neutron.agent.l2.extensions =
neutron_classifier = neutron_classifier.services.classification.extension:NeutronClassifierExtension
[build_sphinx] [build_sphinx]
source-dir = doc/source source-dir = doc/source