Default net assignment

Adds default network assignment to Quark, along with yanking out
repoze.tm2 for compatibility while reinstating the standard
sqlalchemy-style transaction sessions to eliminate the race condition in
mass-assigning IP addresses under load.

Also fixes a typo with ipam_reuse_after which was previously, and
incorrectly, a boolean config value.
This commit is contained in:
Matt Dietz and John Yolo Perkins
2013-09-24 16:20:59 +00:00
committed by Matt Dietz
parent da5f7b4425
commit c758f21558
16 changed files with 625 additions and 551 deletions

View File

@@ -25,7 +25,7 @@ quark_opts = [
help=_('The client to use to talk to the backend')), help=_('The client to use to talk to the backend')),
cfg.StrOpt('ipam_driver', default='quark.ipam.QuarkIpam', cfg.StrOpt('ipam_driver', default='quark.ipam.QuarkIpam',
help=_('IPAM Implementation to use')), help=_('IPAM Implementation to use')),
cfg.BoolOpt('ipam_reuse_after', default=7200, cfg.IntOpt('ipam_reuse_after', default=7200,
help=_("Time in seconds til IP and MAC reuse" help=_("Time in seconds til IP and MAC reuse"
"after deallocation.")), "after deallocation.")),
cfg.StrOpt("strategy_driver", cfg.StrOpt("strategy_driver",

View File

@@ -116,7 +116,7 @@ def _model_query(context, model, filters, fields=None):
# This works even when a non-shared, other-tenant owned network is passed # This works even when a non-shared, other-tenant owned network is passed
# in because the authZ checks that happen in Neutron above us yank it back # in because the authZ checks that happen in Neutron above us yank it back
# out of the result set. # out of the result set.
if "tenant_id" not in filters and not context.is_admin: if not filters and not context.is_admin:
filters["tenant_id"] = [context.tenant_id] filters["tenant_id"] = [context.tenant_id]
if filters.get("tenant_id"): if filters.get("tenant_id"):
@@ -215,7 +215,7 @@ def ip_address_create(context, **address_dict):
@scoped @scoped
def ip_address_find(context, **filters): def ip_address_find(context, lock_mode=False, **filters):
query = context.session.query(models.IPAddress) query = context.session.query(models.IPAddress)
ip_shared = filters.pop("shared", None) ip_shared = filters.pop("shared", None)
@@ -223,6 +223,8 @@ def ip_address_find(context, **filters):
cnt = sql_func.count(models.port_ip_association_table.c.port_id) cnt = sql_func.count(models.port_ip_association_table.c.port_id)
stmt = context.session.query(models.IPAddress, stmt = context.session.query(models.IPAddress,
cnt.label("ports_count")) cnt.label("ports_count"))
if lock_mode:
stmt = stmt.with_lockmode("update")
stmt = stmt.outerjoin(models.port_ip_association_table) stmt = stmt.outerjoin(models.port_ip_association_table)
stmt = stmt.group_by(models.IPAddress).subquery() stmt = stmt.group_by(models.IPAddress).subquery()
@@ -239,13 +241,14 @@ def ip_address_find(context, **filters):
if filters.get("device_id"): if filters.get("device_id"):
model_filters.append(models.IPAddress.ports.any( model_filters.append(models.IPAddress.ports.any(
models.Port.device_id.in_(filters["device_id"]))) models.Port.device_id.in_(filters["device_id"])))
return query.filter(*model_filters) return query.filter(*model_filters)
@scoped @scoped
def mac_address_find(context, **filters): def mac_address_find(context, lock_mode=False, **filters):
query = context.session.query(models.MacAddress) query = context.session.query(models.MacAddress)
if lock_mode:
query.with_lockmode("update")
model_filters = _model_query(context, models.MacAddress, filters) model_filters = _model_query(context, models.MacAddress, filters)
return query.filter(*model_filters) return query.filter(*model_filters)
@@ -253,7 +256,7 @@ def mac_address_find(context, **filters):
def mac_address_range_find_allocation_counts(context, address=None): def mac_address_range_find_allocation_counts(context, address=None):
query = context.session.query(models.MacAddressRange, query = context.session.query(models.MacAddressRange,
sql_func.count(models.MacAddress.address). sql_func.count(models.MacAddress.address).
label("count")) label("count")).with_lockmode("update")
query = query.outerjoin(models.MacAddress) query = query.outerjoin(models.MacAddress)
query = query.group_by(models.MacAddressRange) query = query.group_by(models.MacAddressRange)
query = query.order_by("count DESC") query = query.order_by("count DESC")
@@ -362,7 +365,7 @@ def network_delete(context, network):
def subnet_find_allocation_counts(context, net_id, **filters): def subnet_find_allocation_counts(context, net_id, **filters):
query = context.session.query(models.Subnet, query = context.session.query(models.Subnet,
sql_func.count(models.IPAddress.address). sql_func.count(models.IPAddress.address).
label("count")) label("count")).with_lockmode('update')
query = query.outerjoin(models.Subnet.allocated_ips) query = query.outerjoin(models.Subnet.allocated_ips)
query = query.group_by(models.Subnet) query = query.group_by(models.Subnet)
query = query.order_by("count DESC") query = query.order_by("count DESC")

View File

@@ -24,11 +24,14 @@ from neutron.openstack.common import log as logging
from neutron.openstack.common.notifier import api as notifier_api from neutron.openstack.common.notifier import api as notifier_api
from neutron.openstack.common import timeutils from neutron.openstack.common import timeutils
from oslo.config import cfg
from quark.db import api as db_api from quark.db import api as db_api
from quark.db import models from quark.db import models
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
CONF = cfg.CONF
class QuarkIpam(object): class QuarkIpam(object):
@@ -60,21 +63,24 @@ class QuarkIpam(object):
if mac_address: if mac_address:
mac_address = netaddr.EUI(mac_address).value mac_address = netaddr.EUI(mac_address).value
with context.session.begin(subtransactions=True):
deallocated_mac = db_api.mac_address_find( deallocated_mac = db_api.mac_address_find(
context, reuse_after=reuse_after, scope=db_api.ONE, context, lock_mode=True, reuse_after=reuse_after,
address=mac_address) scope=db_api.ONE, address=mac_address)
if deallocated_mac: if deallocated_mac:
return db_api.mac_address_update( return db_api.mac_address_update(
context, deallocated_mac, deallocated=False, context, deallocated_mac, deallocated=False,
deallocated_at=None) deallocated_at=None)
with context.session.begin(subtransactions=True):
ranges = db_api.mac_address_range_find_allocation_counts( ranges = db_api.mac_address_range_find_allocation_counts(
context, address=mac_address) context, address=mac_address)
for result in ranges: for result in ranges:
rng, addr_count = result rng, addr_count = result
if rng["last_address"] - rng["first_address"] <= addr_count: last = rng["last_address"]
first = rng["first_address"]
if last - first <= addr_count:
continue continue
next_address = None next_address = None
if mac_address: if mac_address:
next_address = mac_address next_address = mac_address
@@ -87,7 +93,8 @@ class QuarkIpam(object):
context, tenant_id=context.tenant_id, context, tenant_id=context.tenant_id,
scope=db_api.ONE, address=next_address) scope=db_api.ONE, address=next_address)
address = db_api.mac_address_create(context, address=next_address, address = db_api.mac_address_create(
context, address=next_address,
mac_address_range_id=rng["id"]) mac_address_range_id=rng["id"])
return address return address
@@ -99,16 +106,23 @@ class QuarkIpam(object):
if ip_address: if ip_address:
ip_address = netaddr.IPAddress(ip_address) ip_address = netaddr.IPAddress(ip_address)
with context.session.begin(subtransactions=True):
address = db_api.ip_address_find( address = db_api.ip_address_find(
elevated, network_id=net_id, reuse_after=reuse_after, elevated, network_id=net_id, reuse_after=reuse_after,
deallocated=True, scope=db_api.ONE, ip_address=ip_address) deallocated=True, scope=db_api.ONE, ip_address=ip_address,
if address: lock_mode=True)
return db_api.ip_address_update(
elevated, address, deallocated=False, deallocated_at=None)
if address:
updated_address = db_api.ip_address_update(
elevated, address, deallocated=False,
deallocated_at=None)
return updated_address
with context.session.begin(subtransactions=True):
subnet = self._choose_available_subnet( subnet = self._choose_available_subnet(
elevated, net_id, ip_address=ip_address, version=version) elevated, net_id, ip_address=ip_address, version=version)
ip_policy_rules = models.IPPolicy.get_ip_policy_rule_set(subnet) ip_policy_rules = models.IPPolicy.get_ip_policy_rule_set(
subnet)
# Creating this IP for the first time # Creating this IP for the first time
next_ip = None next_ip = None
@@ -118,7 +132,8 @@ class QuarkIpam(object):
elevated, network_id=net_id, ip_address=next_ip, elevated, network_id=net_id, ip_address=next_ip,
tenant_id=elevated.tenant_id, scope=db_api.ONE) tenant_id=elevated.tenant_id, scope=db_api.ONE)
if address: if address:
raise exceptions.IpAddressGenerationFailure(net_id=net_id) raise exceptions.IpAddressGenerationFailure(
net_id=net_id)
else: else:
address = True address = True
while address: while address:
@@ -132,7 +147,7 @@ class QuarkIpam(object):
address = db_api.ip_address_find( address = db_api.ip_address_find(
elevated, network_id=net_id, ip_address=next_ip, elevated, network_id=net_id, ip_address=next_ip,
tenant_id=elevated.tenant_id, scope=db_api.ONE) tenant_id=elevated.tenant_id, scope=db_api.ONE)
context.session.add(subnet)
address = db_api.ip_address_create( address = db_api.ip_address_create(
elevated, address=next_ip, subnet_id=subnet["id"], elevated, address=next_ip, subnet_id=subnet["id"],
version=subnet["ip_version"], network_id=net_id) version=subnet["ip_version"], network_id=net_id)
@@ -167,13 +182,15 @@ class QuarkIpam(object):
payload) payload)
def deallocate_ip_address(self, context, port, **kwargs): def deallocate_ip_address(self, context, port, **kwargs):
with context.session.begin(subtransactions=True):
for addr in port["ip_addresses"]: for addr in port["ip_addresses"]:
# Note: only deallocate ip if this is the only port mapped to it # Note: only deallocate ip if this is the only port mapped
if len(addr["ports"]) == 1: if len(addr["ports"]) == 1:
self._deallocate_ip_address(context, addr) self._deallocate_ip_address(context, addr)
port["ip_addresses"] = [] port["ip_addresses"] = []
def deallocate_mac_address(self, context, address): def deallocate_mac_address(self, context, address):
with context.session.begin(subtransactions=True):
mac = db_api.mac_address_find(context, address=address, mac = db_api.mac_address_find(context, address=address,
scope=db_api.ONE) scope=db_api.ONE)
if not mac: if not mac:

View File

@@ -18,13 +18,9 @@ v2 Neutron Plug-in API Quark Implementation
""" """
from oslo.config import cfg from oslo.config import cfg
from sqlalchemy.orm import sessionmaker, scoped_session
from zope import sqlalchemy as zsa
from neutron.db import api as neutron_db_api from neutron.db import api as neutron_db_api
from neutron.extensions import securitygroup as sg_ext from neutron.extensions import securitygroup as sg_ext
from neutron import neutron_plugin_base_v2 from neutron import neutron_plugin_base_v2
from neutron.openstack.common.db.sqlalchemy import session as neutron_session
from neutron import quota from neutron import quota
from quark.api import extensions from quark.api import extensions
@@ -79,15 +75,8 @@ class Plugin(neutron_plugin_base_v2.NeutronPluginBaseV2,
"subnets_quark", "provider", "subnets_quark", "provider",
"ip_policies", "quotas"] "ip_policies", "quotas"]
def _initDBMaker(self):
# This needs to be called after _ENGINE is configured
session_maker = sessionmaker(bind=neutron_session._ENGINE,
extension=zsa.ZopeTransactionExtension())
neutron_session._MAKER = scoped_session(session_maker)
def __init__(self): def __init__(self):
neutron_db_api.configure_db() neutron_db_api.configure_db()
self._initDBMaker()
neutron_db_api.register_models(base=models.BASEV2) neutron_db_api.register_models(base=models.BASEV2)
def get_mac_address_range(self, context, id, fields=None): def get_mac_address_range(self, context, id, fields=None):

View File

@@ -60,6 +60,7 @@ def create_ip_address(context, ip_address):
raise exceptions.BadRequest( raise exceptions.BadRequest(
resource="ip_addresses", resource="ip_addresses",
msg="network_id is required if device_ids are supplied.") msg="network_id is required if device_ids are supplied.")
with context.session.begin():
if network_id and device_ids: if network_id and device_ids:
for device_id in device_ids: for device_id in device_ids:
port = db_api.port_find( port = db_api.port_find(
@@ -95,6 +96,7 @@ def update_ip_address(context, id, ip_address):
LOG.info("update_ip_address %s for tenant %s" % LOG.info("update_ip_address %s for tenant %s" %
(id, context.tenant_id)) (id, context.tenant_id))
with context.session.begin():
address = db_api.ip_address_find( address = db_api.ip_address_find(
context, id=id, tenant_id=context.tenant_id, scope=db_api.ONE) context, id=id, tenant_id=context.tenant_id, scope=db_api.ONE)

View File

@@ -42,6 +42,7 @@ def create_ip_policy(context, ip_policy):
resource="ip_policy", resource="ip_policy",
msg="network_ids or subnet_ids not specified") msg="network_ids or subnet_ids not specified")
with context.session.begin():
models = [] models = []
if subnet_ids: if subnet_ids:
subnets = db_api.subnet_find( subnets = db_api.subnet_find(
@@ -85,6 +86,7 @@ def update_ip_policy(context, id, ip_policy):
ipp = ip_policy["ip_policy"] ipp = ip_policy["ip_policy"]
with context.session.begin():
ipp_db = db_api.ip_policy_find(context, id=id, scope=db_api.ONE) ipp_db = db_api.ip_policy_find(context, id=id, scope=db_api.ONE)
if not ipp_db: if not ipp_db:
raise quark_exceptions.IPPolicyNotFound(id=id) raise quark_exceptions.IPPolicyNotFound(id=id)
@@ -105,7 +107,8 @@ def update_ip_policy(context, id, ip_policy):
if network_ids: if network_ids:
for network in ipp_db["networks"]: for network in ipp_db["networks"]:
network["ip_policy"] = None network["ip_policy"] = None
nets = db_api.network_find(context, id=network_ids, scope=db_api.ALL) nets = db_api.network_find(context, id=network_ids,
scope=db_api.ALL)
if len(nets) != len(network_ids): if len(nets) != len(network_ids):
raise exceptions.NetworkNotFound(net_id=network_ids) raise exceptions.NetworkNotFound(net_id=network_ids)
models.extend(nets) models.extend(nets)
@@ -122,6 +125,7 @@ def update_ip_policy(context, id, ip_policy):
def delete_ip_policy(context, id): def delete_ip_policy(context, id):
LOG.info("delete_ip_policy %s for tenant %s" % (id, context.tenant_id)) LOG.info("delete_ip_policy %s for tenant %s" % (id, context.tenant_id))
with context.session.begin():
ipp = db_api.ip_policy_find(context, id=id, scope=db_api.ONE) ipp = db_api.ip_policy_find(context, id=id, scope=db_api.ONE)
if not ipp: if not ipp:
raise quark_exceptions.IPPolicyNotFound(id=id) raise quark_exceptions.IPPolicyNotFound(id=id)

View File

@@ -82,6 +82,7 @@ def create_mac_address_range(context, mac_range):
LOG.info("create_mac_address_range for tenant %s" % context.tenant_id) LOG.info("create_mac_address_range for tenant %s" % context.tenant_id)
cidr = mac_range["mac_address_range"]["cidr"] cidr = mac_range["mac_address_range"]["cidr"]
cidr, first_address, last_address = _to_mac_range(cidr) cidr, first_address, last_address = _to_mac_range(cidr)
with context.session.begin():
new_range = db_api.mac_address_range_create( new_range = db_api.mac_address_range_create(
context, cidr=cidr, first_address=first_address, context, cidr=cidr, first_address=first_address,
last_address=last_address, next_auto_assign_mac=first_address) last_address=last_address, next_auto_assign_mac=first_address)
@@ -103,6 +104,7 @@ def delete_mac_address_range(context, id):
""" """
LOG.info("delete_mac_address_range %s for tenant %s" % LOG.info("delete_mac_address_range %s for tenant %s" %
(id, context.tenant_id)) (id, context.tenant_id))
with context.session.begin():
mar = db_api.mac_address_range_find(context, id=id, scope=db_api.ONE) mar = db_api.mac_address_range_find(context, id=id, scope=db_api.ONE)
if not mar: if not mar:
raise quark_exceptions.MacAddressRangeNotFound( raise quark_exceptions.MacAddressRangeNotFound(

View File

@@ -60,6 +60,7 @@ def create_network(context, network):
""" """
LOG.info("create_network for tenant %s" % context.tenant_id) LOG.info("create_network for tenant %s" % context.tenant_id)
with context.session.begin():
# Generate a uuid that we're going to hand to the backend and db # Generate a uuid that we're going to hand to the backend and db
net_uuid = uuidutils.generate_uuid() net_uuid = uuidutils.generate_uuid()
@@ -77,9 +78,9 @@ def create_network(context, network):
#TODO(dietz or perkins): Allow this to be overridden later with CLI #TODO(dietz or perkins): Allow this to be overridden later with CLI
default_net_type = CONF.QUARK.default_network_type default_net_type = CONF.QUARK.default_network_type
net_driver = registry.DRIVER_REGISTRY.get_driver(default_net_type) net_driver = registry.DRIVER_REGISTRY.get_driver(default_net_type)
net_driver.create_network(context, net_attrs["name"], network_id=net_uuid, net_driver.create_network(context, net_attrs["name"],
phys_type=pnet_type, phys_net=phys_net, network_id=net_uuid, phys_type=pnet_type,
segment_id=seg_id) phys_net=phys_net, segment_id=seg_id)
subs = net_attrs.pop("subnets", []) subs = net_attrs.pop("subnets", [])
@@ -115,6 +116,7 @@ def update_network(context, id, network):
""" """
LOG.info("update_network %s for tenant %s" % LOG.info("update_network %s for tenant %s" %
(id, context.tenant_id)) (id, context.tenant_id))
with context.session.begin():
net = db_api.network_find(context, id=id, scope=db_api.ONE) net = db_api.network_find(context, id=id, scope=db_api.ONE)
if not net: if not net:
raise exceptions.NetworkNotFound(net_id=id) raise exceptions.NetworkNotFound(net_id=id)
@@ -198,6 +200,7 @@ def delete_network(context, id):
: param id: UUID representing the network to delete. : param id: UUID representing the network to delete.
""" """
LOG.info("delete_network %s for tenant %s" % (id, context.tenant_id)) LOG.info("delete_network %s for tenant %s" % (id, context.tenant_id))
with context.session.begin():
net = db_api.network_find(context, id=id, scope=db_api.ONE) net = db_api.network_find(context, id=id, scope=db_api.ONE)
if not net: if not net:
raise exceptions.NetworkNotFound(net_id=id) raise exceptions.NetworkNotFound(net_id=id)

View File

@@ -52,9 +52,10 @@ def create_port(context, port):
net_id = port_attrs["network_id"] net_id = port_attrs["network_id"]
addresses = [] addresses = []
with context.session.begin():
port_id = uuidutils.generate_uuid() port_id = uuidutils.generate_uuid()
net = db_api.network_find(context, id=net_id, shared=True, net = db_api.network_find(context, id=net_id,
segment_id=segment_id, scope=db_api.ONE) segment_id=segment_id, scope=db_api.ONE)
if not net: if not net:
# Maybe it's a tenant network # Maybe it's a tenant network
@@ -92,7 +93,8 @@ def create_port(context, port):
'ip_address': address.get('address_readable', '')} 'ip_address': address.get('address_readable', '')}
for address in addresses] for address in addresses]
net_driver = registry.DRIVER_REGISTRY.get_driver(net["network_plugin"]) net_driver = registry.DRIVER_REGISTRY.get_driver(net["network_plugin"])
backend_port = net_driver.create_port(context, net["id"], port_id=port_id, backend_port = net_driver.create_port(context, net["id"],
port_id=port_id,
security_groups=group_ids, security_groups=group_ids,
allowed_pairs=address_pairs) allowed_pairs=address_pairs)
@@ -121,6 +123,7 @@ def update_port(context, id, port):
neutron/api/v2/attributes.py. neutron/api/v2/attributes.py.
""" """
LOG.info("update_port %s for tenant %s" % (id, context.tenant_id)) LOG.info("update_port %s for tenant %s" % (id, context.tenant_id))
with context.session.begin():
port_db = db_api.port_find(context, id=id, scope=db_api.ONE) port_db = db_api.port_find(context, id=id, scope=db_api.ONE)
if not port_db: if not port_db:
raise exceptions.PortNotFound(port_id=id) raise exceptions.PortNotFound(port_id=id)
@@ -170,6 +173,7 @@ def post_update_port(context, id, port):
raise exceptions.BadRequest(resource="ports", raise exceptions.BadRequest(resource="ports",
msg="Port body required") msg="Port body required")
with context.session.begin():
port_db = db_api.port_find(context, id=id, scope=db_api.ONE) port_db = db_api.port_find(context, id=id, scope=db_api.ONE)
if not port_db: if not port_db:
raise exceptions.PortNotFound(port_id=id, net_id="") raise exceptions.PortNotFound(port_id=id, net_id="")
@@ -194,7 +198,8 @@ def post_update_port(context, id, port):
if not address: if not address:
address = ipam_driver.allocate_ip_address( address = ipam_driver.allocate_ip_address(
context, port_db["network_id"], id, context, port_db["network_id"], id,
CONF.QUARK.ipam_reuse_after, ip_address=ip_address) CONF.QUARK.ipam_reuse_after,
ip_address=ip_address)
else: else:
address = ipam_driver.allocate_ip_address( address = ipam_driver.allocate_ip_address(
context, port_db["network_id"], id, context, port_db["network_id"], id,
@@ -296,6 +301,7 @@ def delete_port(context, id):
if not port: if not port:
raise exceptions.PortNotFound(net_id=id) raise exceptions.PortNotFound(net_id=id)
with context.session.begin():
backend_key = port["backend_key"] backend_key = port["backend_key"]
mac_address = netaddr.EUI(port["mac_address"]).value mac_address = netaddr.EUI(port["mac_address"]).value
ipam_driver.deallocate_mac_address(context, mac_address) ipam_driver.deallocate_mac_address(context, mac_address)
@@ -317,6 +323,7 @@ def disassociate_port(context, id, ip_address_id):
""" """
LOG.info("disassociate_port %s for tenant %s ip_address_id %s" % LOG.info("disassociate_port %s for tenant %s ip_address_id %s" %
(id, context.tenant_id, ip_address_id)) (id, context.tenant_id, ip_address_id))
with context.session.begin():
port = db_api.port_find(context, id=id, ip_address_id=[ip_address_id], port = db_api.port_find(context, id=id, ip_address_id=[ip_address_id],
scope=db_api.ONE) scope=db_api.ONE)

View File

@@ -49,6 +49,7 @@ def create_route(context, route):
LOG.info("create_route for tenant %s" % context.tenant_id) LOG.info("create_route for tenant %s" % context.tenant_id)
route = route["route"] route = route["route"]
subnet_id = route["subnet_id"] subnet_id = route["subnet_id"]
with context.session.begin():
subnet = db_api.subnet_find(context, id=subnet_id, scope=db_api.ONE) subnet = db_api.subnet_find(context, id=subnet_id, scope=db_api.ONE)
if not subnet: if not subnet:
raise exceptions.SubnetNotFound(subnet_id=subnet_id) raise exceptions.SubnetNotFound(subnet_id=subnet_id)
@@ -74,6 +75,7 @@ def delete_route(context, id):
# admin and only filter on tenant if they aren't. Correct # admin and only filter on tenant if they aren't. Correct
# for all the above later # for all the above later
LOG.info("delete_route %s for tenant %s" % (id, context.tenant_id)) LOG.info("delete_route %s for tenant %s" % (id, context.tenant_id))
with context.session.begin():
route = db_api.route_find(context, id, scope=db_api.ONE) route = db_api.route_find(context, id, scope=db_api.ONE)
if not route: if not route:
raise quark_exceptions.RouteNotFound(route_id=id) raise quark_exceptions.RouteNotFound(route_id=id)

View File

@@ -76,6 +76,7 @@ def create_security_group(context, security_group, net_driver):
raise sg_ext.SecurityGroupDefaultAlreadyExists() raise sg_ext.SecurityGroupDefaultAlreadyExists()
group_id = uuidutils.generate_uuid() group_id = uuidutils.generate_uuid()
with context.session.begin():
net_driver.create_security_group( net_driver.create_security_group(
context, context,
group_name, group_name,
@@ -121,6 +122,8 @@ def _create_default_security_group(context, net_driver):
def create_security_group_rule(context, security_group_rule, net_driver): def create_security_group_rule(context, security_group_rule, net_driver):
LOG.info("create_security_group for tenant %s" % LOG.info("create_security_group for tenant %s" %
(context.tenant_id)) (context.tenant_id))
with context.session.begin():
rule = _validate_security_group_rule( rule = _validate_security_group_rule(
context, security_group_rule["security_group_rule"]) context, security_group_rule["security_group_rule"])
rule["id"] = uuidutils.generate_uuid() rule["id"] = uuidutils.generate_uuid()
@@ -145,6 +148,7 @@ def delete_security_group(context, id, net_driver):
LOG.info("delete_security_group %s for tenant %s" % LOG.info("delete_security_group %s for tenant %s" %
(id, context.tenant_id)) (id, context.tenant_id))
with context.session.begin():
group = db_api.security_group_find(context, id=id, scope=db_api.ONE) group = db_api.security_group_find(context, id=id, scope=db_api.ONE)
#TODO(anyone): name and ports are lazy-loaded. Could be good op later #TODO(anyone): name and ports are lazy-loaded. Could be good op later
@@ -161,6 +165,7 @@ def delete_security_group(context, id, net_driver):
def delete_security_group_rule(context, id, net_driver): def delete_security_group_rule(context, id, net_driver):
LOG.info("delete_security_group %s for tenant %s" % LOG.info("delete_security_group %s for tenant %s" %
(id, context.tenant_id)) (id, context.tenant_id))
with context.session.begin():
rule = db_api.security_group_rule_find(context, id=id, rule = db_api.security_group_rule_find(context, id=id,
scope=db_api.ONE) scope=db_api.ONE)
if not rule: if not rule:
@@ -219,6 +224,7 @@ def update_security_group(context, id, security_group, net_driver):
if id == DEFAULT_SG_UUID: if id == DEFAULT_SG_UUID:
raise sg_ext.SecurityGroupCannotUpdateDefault() raise sg_ext.SecurityGroupCannotUpdateDefault()
new_group = security_group["security_group"] new_group = security_group["security_group"]
with context.session.begin():
group = db_api.security_group_find(context, id=id, scope=db_api.ONE) group = db_api.security_group_find(context, id=id, scope=db_api.ONE)
net_driver.update_security_group(context, id, **new_group) net_driver.update_security_group(context, id, **new_group)

View File

@@ -85,6 +85,7 @@ def create_subnet(context, subnet):
LOG.info("create_subnet for tenant %s" % context.tenant_id) LOG.info("create_subnet for tenant %s" % context.tenant_id)
net_id = subnet["subnet"]["network_id"] net_id = subnet["subnet"]["network_id"]
with context.session.begin():
net = db_api.network_find(context, id=net_id, scope=db_api.ONE) net = db_api.network_find(context, id=net_id, scope=db_api.ONE)
if not net: if not net:
raise exceptions.NetworkNotFound(net_id=net_id) raise exceptions.NetworkNotFound(net_id=net_id)
@@ -161,6 +162,7 @@ def update_subnet(context, id, subnet):
LOG.info("update_subnet %s for tenant %s" % LOG.info("update_subnet %s for tenant %s" %
(id, context.tenant_id)) (id, context.tenant_id))
with context.session.begin():
subnet_db = db_api.subnet_find(context, id=id, scope=db_api.ONE) subnet_db = db_api.subnet_find(context, id=id, scope=db_api.ONE)
if not subnet_db: if not subnet_db:
raise exceptions.SubnetNotFound(id=id) raise exceptions.SubnetNotFound(id=id)
@@ -292,6 +294,7 @@ def delete_subnet(context, id):
: param id: UUID representing the subnet to delete. : param id: UUID representing the subnet to delete.
""" """
LOG.info("delete_subnet %s for tenant %s" % (id, context.tenant_id)) LOG.info("delete_subnet %s for tenant %s" % (id, context.tenant_id))
with context.session.begin():
subnet = db_api.subnet_find(context, id=id, scope=db_api.ONE) subnet = db_api.subnet_find(context, id=id, scope=db_api.ONE)
if not subnet: if not subnet:
raise exceptions.SubnetNotFound(subnet_id=id) raise exceptions.SubnetNotFound(subnet_id=id)

View File

@@ -87,6 +87,7 @@ def _make_subnet_dict(subnet, default_route=None, fields=None):
"allocation_pools": _allocation_pools(subnet), "allocation_pools": _allocation_pools(subnet),
"dns_nameservers": dns_nameservers or [], "dns_nameservers": dns_nameservers or [],
"cidr": subnet.get("cidr"), "cidr": subnet.get("cidr"),
"shared": STRATEGY.is_parent_network(net_id),
"enable_dhcp": None} "enable_dhcp": None}
def _host_route(route): def _host_route(route):

View File

@@ -15,6 +15,7 @@
import contextlib import contextlib
import copy import copy
import time
import uuid import uuid
import mock import mock
@@ -697,7 +698,17 @@ class TestQuarkDeleteSubnet(test_quark_plugin.TestQuarkPlugin):
class TestSubnetsNotification(test_quark_plugin.TestQuarkPlugin): class TestSubnetsNotification(test_quark_plugin.TestQuarkPlugin):
@contextlib.contextmanager @contextlib.contextmanager
def _stubs(self, s, deleted_at=None): def _stubs(self, s, deleted_at=None):
class FakeContext(object):
def __enter__(*args, **kwargs):
pass
def __exit__(*args, **kwargs):
pass
self.context.session.begin = FakeContext
s["network"] = models.Network() s["network"] = models.Network()
s["network"]["created_at"] = s["created_at"]
subnet = models.Subnet(**s) subnet = models.Subnet(**s)
db_mod = "quark.db.api" db_mod = "quark.db.api"
api_mod = "neutron.openstack.common.notifier.api" api_mod = "neutron.openstack.common.notifier.api"
@@ -706,13 +717,15 @@ class TestSubnetsNotification(test_quark_plugin.TestQuarkPlugin):
mock.patch("%s.subnet_find" % db_mod), mock.patch("%s.subnet_find" % db_mod),
mock.patch("%s.network_find" % db_mod), mock.patch("%s.network_find" % db_mod),
mock.patch("%s.subnet_create" % db_mod), mock.patch("%s.subnet_create" % db_mod),
mock.patch("%s.ip_policy_create" % db_mod),
mock.patch("%s.subnet_delete" % db_mod), mock.patch("%s.subnet_delete" % db_mod),
mock.patch("%s.notify" % api_mod), mock.patch("%s.notify" % api_mod),
mock.patch("%s.utcnow" % time_mod) mock.patch("%s.utcnow" % time_mod)
) as (sub_find, net_find, sub_create, sub_del, notify, time): ) as (sub_find, net_find, sub_create, pol_cre, sub_del, notify,
time_func):
sub_create.return_value = subnet sub_create.return_value = subnet
sub_find.return_value = subnet sub_find.return_value = subnet
time.return_value = deleted_at time_func.return_value = deleted_at
yield notify yield notify
def test_create_subnet_notification(self): def test_create_subnet_notification(self):
@@ -730,8 +743,10 @@ class TestSubnetsNotification(test_quark_plugin.TestQuarkPlugin):
created_at=s["created_at"])) created_at=s["created_at"]))
def test_delete_subnet_notification(self): def test_delete_subnet_notification(self):
s = dict(tenant_id=1, id=1, created_at="123") now = time.strftime('%Y-%m-%d %H:%M:%S')
with self._stubs(s, deleted_at="456") as notify: later = time.strftime('%Y-%m-%d %H:%M:%S')
s = dict(tenant_id=1, id=1, created_at=now)
with self._stubs(s, deleted_at=later) as notify:
self.plugin.delete_subnet(self.context, 1) self.plugin.delete_subnet(self.context, 1)
notify.assert_called_once_with( notify.assert_called_once_with(
self.context, self.context,
@@ -741,7 +756,7 @@ class TestSubnetsNotification(test_quark_plugin.TestQuarkPlugin):
dict(tenant_id=s["tenant_id"], dict(tenant_id=s["tenant_id"],
created_at=s["created_at"], created_at=s["created_at"],
ip_block_id=s["id"], ip_block_id=s["id"],
deleted_at="456")) deleted_at=later))
class TestQuarkDiagnoseSubnets(test_quark_plugin.TestQuarkPlugin): class TestQuarkDiagnoseSubnets(test_quark_plugin.TestQuarkPlugin):

View File

@@ -24,3 +24,15 @@ class TestBase(unittest2.TestCase):
def setUp(self): def setUp(self):
super(TestBase, self).setUp() super(TestBase, self).setUp()
self.context = context.Context('fake', 'fake', is_admin=False) self.context = context.Context('fake', 'fake', is_admin=False)
class FakeContext(object):
def __new__(cls, *args, **kwargs):
return super(FakeContext, cls).__new__(cls)
def __enter__(*args, **kwargs):
pass
def __exit__(*args, **kwargs):
pass
self.context.session.begin = FakeContext

View File

@@ -202,6 +202,7 @@ class QuarkNewIPAddressAllocation(QuarkIpamBaseTest):
if not addresses: if not addresses:
addresses = [None] addresses = [None]
db_mod = "quark.db.api" db_mod = "quark.db.api"
self.context.session.add = mock.Mock()
with contextlib.nested( with contextlib.nested(
mock.patch("%s.ip_address_find" % db_mod), mock.patch("%s.ip_address_find" % db_mod),
mock.patch("%s.subnet_find_allocation_counts" % db_mod) mock.patch("%s.subnet_find_allocation_counts" % db_mod)
@@ -310,6 +311,7 @@ class QuarkIPAddressAllocateDeallocated(QuarkIpamBaseTest):
@contextlib.contextmanager @contextlib.contextmanager
def _stubs(self, ip_find, subnet, address, addresses_found): def _stubs(self, ip_find, subnet, address, addresses_found):
db_mod = "quark.db.api" db_mod = "quark.db.api"
self.context.session.add = mock.Mock()
with contextlib.nested( with contextlib.nested(
mock.patch("%s.ip_address_find" % db_mod), mock.patch("%s.ip_address_find" % db_mod),
mock.patch("%s.ip_address_update" % db_mod), mock.patch("%s.ip_address_update" % db_mod),
@@ -360,13 +362,17 @@ class QuarkIPAddressAllocateDeallocated(QuarkIpamBaseTest):
This edge case occurs because users are allowed to select a specific IP This edge case occurs because users are allowed to select a specific IP
address to create. address to create.
""" """
network_mod = models.Network()
network_mod.update(dict(ip_policy=None))
subnet = dict(id=1, ip_version=4, next_auto_assign_ip=0, subnet = dict(id=1, ip_version=4, next_auto_assign_ip=0,
cidr="0.0.0.0/24", first_ip=0, last_ip=255, cidr="0.0.0.0/24", first_ip=0, last_ip=255,
network=dict(ip_policy=None), ip_policy=None) network=network_mod, ip_policy=None)
address0 = dict(id=1, address=0) address0 = dict(id=1, address=0)
addresses_found = [None, None] addresses_found = [None, None]
subnet_mod = models.Subnet()
subnet_mod.update(subnet)
with self._stubs( with self._stubs(
False, subnet, address0, addresses_found False, subnet_mod, address0, addresses_found
) as (choose_subnet): ) as (choose_subnet):
ipaddress = self.ipam.allocate_ip_address(self.context, 0, 0, 0) ipaddress = self.ipam.allocate_ip_address(self.context, 0, 0, 0)
self.assertEqual(ipaddress["address"], 2) self.assertEqual(ipaddress["address"], 2)
@@ -380,6 +386,7 @@ class TestQuarkIpPoliciesIpAllocation(QuarkIpamBaseTest):
if not addresses: if not addresses:
addresses = [None] addresses = [None]
db_mod = "quark.db.api" db_mod = "quark.db.api"
self.context.session.add = mock.Mock()
with contextlib.nested( with contextlib.nested(
mock.patch("%s.ip_address_find" % db_mod), mock.patch("%s.ip_address_find" % db_mod),
mock.patch("%s.subnet_find_allocation_counts" % db_mod) mock.patch("%s.subnet_find_allocation_counts" % db_mod)
@@ -491,6 +498,7 @@ class QuarkIPAddressAllocationNotifications(QuarkIpamBaseTest):
db_mod = "quark.db.api" db_mod = "quark.db.api"
api_mod = "neutron.openstack.common.notifier.api" api_mod = "neutron.openstack.common.notifier.api"
time_mod = "neutron.openstack.common.timeutils" time_mod = "neutron.openstack.common.timeutils"
self.context.session.add = mock.Mock()
with contextlib.nested( with contextlib.nested(
mock.patch("%s.ip_address_find" % db_mod), mock.patch("%s.ip_address_find" % db_mod),
mock.patch("%s.ip_address_create" % db_mod), mock.patch("%s.ip_address_create" % db_mod),