Provide support for "All Tenants" access

This moves some of the per-tenant result filtering
into the storage layer, a necessary pre-requisite to supporting
both the v1 and v2 APIs for record access.

Again, Additionally, This fixes a number of issues we had with
our existing tests, where the handling of tenant_id's was
inconsistent.

Change-Id: I1a20a3c81edbf74a7d6256118d206e509c6da65c
This commit is contained in:
Kiall Mac Innes 2013-12-03 17:48:55 +00:00
parent 0d23371c87
commit aa780afb12
17 changed files with 530 additions and 347 deletions

View File

@ -23,7 +23,7 @@ from designate.context import DesignateContext
from designate.openstack.common import jsonutils as json
from designate.openstack.common import local
from designate.openstack.common import log as logging
from designate.openstack.common import uuidutils
from designate.openstack.common import strutils
from designate.openstack.common.rpc import common as rpc_common
LOG = logging.getLogger(__name__)
@ -102,10 +102,10 @@ class KeystoneContextMiddleware(ContextMiddleware):
try:
if headers['X-Identity-Status'] is 'Invalid':
#TODO(graham) fix the return to use non-flask resources
# TODO(graham) fix the return to use non-flask resources
return flask.Response(status=401)
except KeyError:
#If the key is valid, Keystone does not include this header at all
# If the key is valid, Keystone does not include this header at all
pass
if headers.get('X-Service-Catalog'):
@ -124,13 +124,6 @@ class KeystoneContextMiddleware(ContextMiddleware):
# Store the context where oslo-log exepcts to find it.
local.store.context = context
# Attempt to sudo, if requested.
sudo_tenant_id = headers.get('X-Designate-Sudo-Tenant-ID', None)
if sudo_tenant_id and (uuidutils.is_uuid_like(sudo_tenant_id)
or sudo_tenant_id.isdigit()):
context.sudo(sudo_tenant_id)
# Attach the context to the request environment
request.environ['context'] = context
@ -153,6 +146,34 @@ class NoAuthContextMiddleware(ContextMiddleware):
request.environ['context'] = context
class TestContextMiddleware(ContextMiddleware):
def __init__(self, application, tenant_id=None, user_id=None):
super(TestContextMiddleware, self).__init__(application)
LOG.critical('Starting designate testcontext middleware')
LOG.critical('**** DO NOT USE IN PRODUCTION ****')
self.default_tenant_id = tenant_id
self.default_user_id = user_id
def process_request(self, request):
headers = request.headers
all_tenants = strutils.bool_from_string(
headers.get('X-Test-All-Tenants', 'False'))
context = DesignateContext(
user=headers.get('X-Test-User-ID', self.default_user_id),
tenant=headers.get('X-Test-Tenant-ID', self.default_tenant_id),
all_tenants=all_tenants)
# Store the context where oslo-log exepcts to find it.
local.store.context = context
# Attach the context to the request environment
request.environ['context'] = context
class FaultWrapperMiddleware(wsgi.Middleware):
def __init__(self, application):
super(FaultWrapperMiddleware, self).__init__(application)

View File

@ -183,6 +183,9 @@ class Service(rpc_service.Service):
return False
def _is_subdomain(self, context, domain_name):
context = context.elevated()
context.all_tenants = True
# Break the name up into it's component labels
labels = domain_name.split(".")
@ -403,7 +406,10 @@ class Service(rpc_service.Service):
# Domain Methods
def create_domain(self, context, values):
# TODO(kiall): Refactor this method into *MUCH* smaller chunks.
values['tenant_id'] = context.tenant_id
# Default to creating in the current users tenant
if 'tenant_id' not in values:
values['tenant_id'] = context.tenant_id
target = {
'tenant_id': values['tenant_id'],
@ -483,21 +489,12 @@ class Service(rpc_service.Service):
target = {'tenant_id': context.tenant_id}
policy.check('find_domains', context, target)
if criterion is None:
criterion = {}
if not context.is_admin:
criterion['tenant_id'] = context.tenant_id
return self.storage_api.find_domains(context, criterion)
def find_domain(self, context, criterion):
target = {'tenant_id': context.tenant_id}
policy.check('find_domain', context, target)
if not context.is_admin:
criterion['tenant_id'] = context.tenant_id
return self.storage_api.find_domain(context, criterion)
def update_domain(self, context, domain_id, values, increment_serial=True):

View File

@ -16,7 +16,6 @@
import itertools
from designate.openstack.common import context
from designate.openstack.common import log as logging
from designate import policy
LOG = logging.getLogger(__name__)
@ -24,7 +23,7 @@ LOG = logging.getLogger(__name__)
class DesignateContext(context.RequestContext):
def __init__(self, auth_token=None, user=None, tenant=None, is_admin=False,
read_only=False, show_deleted=False, request_id=None,
original_tenant_id=None, roles=[], service_catalog=None):
roles=[], service_catalog=None, all_tenants=False):
super(DesignateContext, self).__init__(
auth_token=auth_token,
user=user,
@ -34,29 +33,9 @@ class DesignateContext(context.RequestContext):
show_deleted=show_deleted,
request_id=request_id)
self._original_tenant_id = original_tenant_id
self.roles = roles
self.service_catalog = service_catalog
def sudo(self, tenant_id, force=False):
if force:
allowed_sudo = True
else:
# We use exc=None here since the context is built early in the
# request lifecycle, outside of our ordinary error handling.
# For now, we silently ignore failed sudo requests.
target = {'tenant_id': tenant_id}
allowed_sudo = policy.check('use_sudo', self, target, exc=None)
if allowed_sudo:
LOG.warn('Accepted sudo from user_id %s for tenant_id %s'
% (self.user_id, tenant_id))
self.original_tenant_id = self.tenant_id
self.tenant_id = tenant_id
else:
LOG.warn('Rejected sudo from user_id %s for tenant_id %s'
% (self.user_id, tenant_id))
self.all_tenants = all_tenants
def deepcopy(self):
d = self.to_dict()
@ -73,9 +52,9 @@ class DesignateContext(context.RequestContext):
d.update({
'user_id': self.user_id,
'tenant_id': self.tenant_id,
'original_tenant_id': self.original_tenant_id,
'roles': self.roles,
'service_catalog': self.service_catalog
'service_catalog': self.service_catalog,
'all_tenants': self.all_tenants,
})
return d
@ -113,17 +92,6 @@ class DesignateContext(context.RequestContext):
def tenant_id(self, value):
self.tenant = value
@property
def original_tenant_id(self):
if self._original_tenant_id:
return self._original_tenant_id
else:
return self.tenant
@original_tenant_id.setter
def original_tenant_id(self, value):
self._original_tenant_id = value
@classmethod
def get_admin_context(cls, **kwargs):
# TODO(kiall): Remove Me

View File

@ -69,7 +69,7 @@ class Handler(Plugin):
"""
Return the domain for this context
"""
context = DesignateContext.get_admin_context()
context = DesignateContext.get_admin_context(all_tenants=True)
return central_api.get_domain(context, domain_id)
@ -99,7 +99,7 @@ class BaseAddressHandler(Handler):
LOG.debug('Event data: %s' % data)
data['domain'] = domain['name']
context = DesignateContext.get_admin_context()
context = DesignateContext.get_admin_context(all_tenants=True)
for addr in addresses:
record_data = data.copy()
@ -128,7 +128,7 @@ class BaseAddressHandler(Handler):
:param criterion: Criterion to search and destroy records
"""
context = DesignateContext.get_admin_context()
context = DesignateContext.get_admin_context(all_tenants=True)
if managed:
criterion.update({

View File

@ -0,0 +1,133 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2011 OpenStack Foundation.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
System-level utilities and helper functions.
"""
import logging
import sys
LOG = logging.getLogger(__name__)
def int_from_bool_as_string(subject):
"""
Interpret a string as a boolean and return either 1 or 0.
Any string value in:
('True', 'true', 'On', 'on', '1')
is interpreted as a boolean True.
Useful for JSON-decoded stuff and config file parsing
"""
return bool_from_string(subject) and 1 or 0
def bool_from_string(subject):
"""
Interpret a string as a boolean.
Any string value in:
('True', 'true', 'On', 'on', 'Yes', 'yes', '1')
is interpreted as a boolean True.
Useful for JSON-decoded stuff and config file parsing
"""
if isinstance(subject, bool):
return subject
if isinstance(subject, basestring):
if subject.strip().lower() in ('true', 'on', 'yes', '1'):
return True
return False
def safe_decode(text, incoming=None, errors='strict'):
"""
Decodes incoming str using `incoming` if they're
not already unicode.
:param incoming: Text's current encoding
:param errors: Errors handling policy. See here for valid
values http://docs.python.org/2/library/codecs.html
:returns: text or a unicode `incoming` encoded
representation of it.
:raises TypeError: If text is not an isntance of basestring
"""
if not isinstance(text, basestring):
raise TypeError("%s can't be decoded" % type(text))
if isinstance(text, unicode):
return text
if not incoming:
incoming = (sys.stdin.encoding or
sys.getdefaultencoding())
try:
return text.decode(incoming, errors)
except UnicodeDecodeError:
# Note(flaper87) If we get here, it means that
# sys.stdin.encoding / sys.getdefaultencoding
# didn't return a suitable encoding to decode
# text. This happens mostly when global LANG
# var is not set correctly and there's no
# default encoding. In this case, most likely
# python will use ASCII or ANSI encoders as
# default encodings but they won't be capable
# of decoding non-ASCII characters.
#
# Also, UTF-8 is being used since it's an ASCII
# extension.
return text.decode('utf-8', errors)
def safe_encode(text, incoming=None,
encoding='utf-8', errors='strict'):
"""
Encodes incoming str/unicode using `encoding`. If
incoming is not specified, text is expected to
be encoded with current python's default encoding.
(`sys.getdefaultencoding`)
:param incoming: Text's current encoding
:param encoding: Expected encoding for text (Default UTF-8)
:param errors: Errors handling policy. See here for valid
values http://docs.python.org/2/library/codecs.html
:returns: text or a bytestring `encoding` encoded
representation of it.
:raises TypeError: If text is not an isntance of basestring
"""
if not isinstance(text, basestring):
raise TypeError("%s can't be encoded" % type(text))
if not incoming:
incoming = (sys.stdin.encoding or
sys.getdefaultencoding())
if isinstance(text, unicode):
return text.encode(encoding, errors)
elif text and encoding != incoming:
# Decode text before encoding it with `encoding`
text = safe_decode(text, incoming, errors)
return text.encode(encoding, errors)
return text

View File

@ -40,6 +40,9 @@ class StorageQuota(Quota):
return dict((q['resource'], q['hard_limit']) for q in quotas)
def get_quota(self, context, tenant_id, resource):
context = context.deepcopy()
context.all_tenants = True
quota = self.storage_api.find_quota(context, {
'tenant_id': tenant_id,
'resource': resource,
@ -48,6 +51,9 @@ class StorageQuota(Quota):
return {resource: quota['hard_limit']}
def set_quota(self, context, tenant_id, resource, hard_limit):
context = context.deepcopy()
context.all_tenants = True
def create_quota():
values = {
'tenant_id': tenant_id,
@ -81,6 +87,9 @@ class StorageQuota(Quota):
return {resource: hard_limit}
def reset_quotas(self, context, tenant_id):
context = context.deepcopy()
context.all_tenants = True
quotas = self.storage_api.find_quotas(context, {
'tenant_id': tenant_id,
})

View File

@ -65,6 +65,15 @@ class SQLAlchemyStorage(base.Storage):
return query
def _apply_tenant_criteria(self, context, model, query):
if hasattr(model, 'tenant_id'):
if context.all_tenants:
LOG.debug('Including all tenants items in query results')
else:
query = query.filter(model.tenant_id == context.tenant_id)
return query
def _apply_deleted_criteria(self, context, model, query):
if issubclass(model, SoftDeleteMixin):
if context.show_deleted:
@ -83,6 +92,7 @@ class SQLAlchemyStorage(base.Storage):
# First up, create a query and apply the various filters
query = self.session.query(model)
query = self._apply_criterion(model, query, criterion)
query = self._apply_tenant_criteria(context, model, query)
query = self._apply_deleted_criteria(context, model, query)
if one:
@ -255,6 +265,7 @@ class SQLAlchemyStorage(base.Storage):
# returns an array of tenant_id & count of their domains
query = self.session.query(models.Domain.tenant_id,
func.count(models.Domain.id))
query = self._apply_tenant_criteria(context, models.Domain, query)
query = self._apply_deleted_criteria(context, models.Domain, query)
query = query.group_by(models.Domain.tenant_id)
@ -263,8 +274,9 @@ class SQLAlchemyStorage(base.Storage):
def get_tenant(self, context, tenant_id):
# get list list & count of all domains owned by given tenant_id
query = self.session.query(models.Domain.name)
query = query.filter(models.Domain.tenant_id == tenant_id)
query = self._apply_tenant_criteria(context, models.Domain, query)
query = self._apply_deleted_criteria(context, models.Domain, query)
query = query.filter(models.Domain.tenant_id == tenant_id)
result = query.all()
@ -278,6 +290,7 @@ class SQLAlchemyStorage(base.Storage):
# tenants are the owner of domains, count the number of unique tenants
# select count(distinct tenant_id) from domains
query = self.session.query(distinct(models.Domain.tenant_id))
query = self._apply_tenant_criteria(context, models.Domain, query)
query = self._apply_deleted_criteria(context, models.Domain, query)
return query.count()
@ -339,6 +352,7 @@ class SQLAlchemyStorage(base.Storage):
def count_domains(self, context, criterion=None):
query = self.session.query(models.Domain)
query = self._apply_criterion(models.Domain, query, criterion)
query = self._apply_tenant_criteria(context, models.Domain, query)
query = self._apply_deleted_criteria(context, models.Domain, query)
return query.count()
@ -410,6 +424,7 @@ class SQLAlchemyStorage(base.Storage):
def count_records(self, context, criterion=None):
query = self.session.query(models.Record)
query = self._apply_tenant_criteria(context, models.Record, query)
query = self._apply_criterion(models.Record, query, criterion)
return query.count()

View File

@ -93,11 +93,9 @@ class PolicyFixture(fixtures.Fixture):
class TestCase(test.BaseTestCase):
quota_fixtures = [{
'tenant_id': '12345',
'resource': 'domains',
'hard_limit': 5,
}, {
'tenant_id': '12345',
'resource': 'records',
'hard_limit': 50,
}]

View File

@ -131,38 +131,6 @@ class KeystoneContextMiddlewareTest(ApiTestCase):
self.assertEqual('TenantID', context.tenant_id)
self.assertEqual(['admin', 'Member'], context.roles)
def test_process_request_sudo(self):
# Set the policy to accept the authz
self.policy({'use_sudo': '@'})
app = middleware.KeystoneContextMiddleware({})
request = FakeRequest()
request.headers = {
'X-Auth-Token': 'AuthToken',
'X-User-ID': 'UserID',
'X-Tenant-ID': 'TenantID',
'X-Roles': 'admin,Member',
'X-Designate-Sudo-Tenant-ID':
'5a993bf8-d521-420a-81e1-192d9cc3d5a0'
}
# Process the request
app.process_request(request)
self.assertIn('context', request.environ)
context = request.environ['context']
self.assertFalse(context.is_admin)
self.assertEqual('AuthToken', context.auth_token)
self.assertEqual('UserID', context.user_id)
self.assertEqual('TenantID', context.original_tenant_id)
self.assertEqual('5a993bf8-d521-420a-81e1-192d9cc3d5a0',
context.tenant_id)
self.assertEqual(['admin', 'Member'], context.roles)
def test_process_request_invalid_keystone_token(self):
app = middleware.KeystoneContextMiddleware({})

View File

@ -37,9 +37,10 @@ class ApiV1Test(ApiTestCase):
self.app.wsgi_app = middleware.FaultWrapperMiddleware(
self.app.wsgi_app)
# Inject the NoAuth middleware
self.app.wsgi_app = middleware.NoAuthContextMiddleware(
self.app.wsgi_app)
# Inject the TestAuth middleware
self.app.wsgi_app = middleware.TestContextMiddleware(
self.app.wsgi_app, self.admin_context.tenant_id,
self.admin_context.user_id)
# Obtain a test client
self.client = self.app.test_client()

View File

@ -25,6 +25,13 @@ LOG = logging.getLogger(__name__)
class ApiV1ServersTest(ApiV1Test):
def setUp(self):
super(ApiV1ServersTest, self).setUp()
# All Server Checks should be performed as an admin, so..
# Override to policy to make everyone an admin.
self.policy({'admin': '@'})
def test_create_server(self):
# Create a server
fixture = self.get_server_fixture(0)

View File

@ -36,8 +36,10 @@ class ApiV2TestCase(ApiTestCase):
# Inject the FaultWrapper middleware
self.app = middleware.FaultWrapperMiddleware(self.app)
# Inject the NoAuth middleware
self.app = middleware.NoAuthContextMiddleware(self.app)
# Inject the TestContext middleware
self.app = middleware.TestContextMiddleware(
self.app, self.admin_context.tenant_id,
self.admin_context.tenant_id)
# Obtain a test client
self.client = TestApp(self.app)

View File

@ -295,18 +295,21 @@ class CentralServiceTest(CentralTestCase):
# Tenant Tests
def test_count_tenants(self):
context = self.get_admin_context()
admin_context = self.get_admin_context()
admin_context.all_tenants = True
tenant_one_context = self.get_context(tenant=1)
tenant_two_context = self.get_context(tenant=2)
# in the beginning, there should be nothing
tenants = self.central_service.count_tenants(self.admin_context)
tenants = self.central_service.count_tenants(admin_context)
self.assertEqual(tenants, 0)
# Explicitly set a tenant_id
context.tenant_id = '1'
self.create_domain(fixture=0, context=context)
context.tenant_id = '2'
self.create_domain(fixture=1, context=context)
self.create_domain(fixture=0, context=tenant_one_context)
self.create_domain(fixture=1, context=tenant_two_context)
tenants = self.central_service.count_tenants(self.admin_context)
tenants = self.central_service.count_tenants(admin_context)
self.assertEqual(tenants, 2)
def test_count_tenants_policy_check(self):
@ -398,20 +401,16 @@ class CentralServiceTest(CentralTestCase):
self.create_domain()
def test_create_subdomain(self):
context = self.get_admin_context()
# Explicitly set a tenant_id
context.tenant_id = '1'
# Create the Parent Domain using fixture 0
parent_domain = self.create_domain(fixture=0, context=context)
parent_domain = self.create_domain(fixture=0)
# Prepare values for the subdomain using fixture 1 as a base
values = self.get_domain_fixture(1)
values['name'] = 'www.%s' % parent_domain['name']
# Create the subdomain
domain = self.central_service.create_domain(context, values=values)
domain = self.central_service.create_domain(
self.admin_context, values=values)
# Ensure all values have been set correctly
self.assertIsNotNone(domain['id'])
@ -446,18 +445,17 @@ class CentralServiceTest(CentralTestCase):
# Set the policy to accept the authz
self.policy({'use_blacklisted_domain': '@'})
# Create a server
self.create_server()
context = self.get_admin_context()
values = dict(
name='blacklisted.com.',
email='info@blacklisted.com'
)
# Create a server
self.create_server()
# Create a domain
domain = self.central_service.create_domain(context, values=values)
domain = self.central_service.create_domain(
self.admin_context, values=values)
# Ensure all values have been set correctly
self.assertIsNotNone(domain['id'])
@ -471,8 +469,6 @@ class CentralServiceTest(CentralTestCase):
# Set the policy to reject the authz
self.policy({'use_blacklisted_domain': '!'})
context = self.get_admin_context()
values = dict(
name='blacklisted.com.',
email='info@blacklisted.com'
@ -480,7 +476,8 @@ class CentralServiceTest(CentralTestCase):
with testtools.ExpectedException(exceptions.InvalidDomainName):
# Create a domain
self.central_service.create_domain(context, values=values)
self.central_service.create_domain(
self.admin_context, values=values)
def _test_create_domain_fail(self, values, exception):
self.config(accepted_tlds_file='tlds-alpha-by-domain.txt.sample',
@ -494,10 +491,10 @@ class CentralServiceTest(CentralTestCase):
self.central_service.effective_tld._load_accepted_tld_list()
self.central_service.effective_tld._load_effective_tld_list()
context = self.get_admin_context()
with testtools.ExpectedException(exception):
# Create an invalid domain
self.central_service.create_domain(context, values=values)
self.central_service.create_domain(
self.admin_context, values=values)
def test_create_domain_invalid_tld_fail(self):
self.config(accepted_tlds_file='tlds-alpha-by-domain.txt.sample',
@ -511,8 +508,6 @@ class CentralServiceTest(CentralTestCase):
self.central_service.effective_tld._load_accepted_tld_list()
self.central_service.effective_tld._load_effective_tld_list()
context = self.get_admin_context()
# Create a server
self.create_server()
@ -522,7 +517,7 @@ class CentralServiceTest(CentralTestCase):
)
# Create a valid domain
self.central_service.create_domain(context, values=values)
self.central_service.create_domain(self.admin_context, values=values)
values = dict(
name='invalid.NeT1.',
@ -531,7 +526,8 @@ class CentralServiceTest(CentralTestCase):
with testtools.ExpectedException(exceptions.InvalidTLD):
# Create an invalid domain
self.central_service.create_domain(context, values=values)
self.central_service.create_domain(
self.admin_context, values=values)
def test_create_domain_effective_tld_fail(self):
values = dict(
@ -539,8 +535,8 @@ class CentralServiceTest(CentralTestCase):
email='info@invalid.com'
)
self._test_create_domain_fail(values,
exceptions.DomainIsSameAsAnEffectiveTLD)
self._test_create_domain_fail(
values, exceptions.DomainIsSameAsAnEffectiveTLD)
def test_idn_create_domain_effective_tld_fail(self):
# Test creation of the effective TLD - brønnøysund.no
@ -549,8 +545,8 @@ class CentralServiceTest(CentralTestCase):
email='info@invalid.com'
)
self._test_create_domain_fail(values,
exceptions.DomainIsSameAsAnEffectiveTLD)
self._test_create_domain_fail(
values, exceptions.DomainIsSameAsAnEffectiveTLD)
def test_create_domain_re_effective_tld_fail(self):
# co.uk is in the regular expression list for effective_tlds
@ -559,21 +555,19 @@ class CentralServiceTest(CentralTestCase):
email='info@invalid.com'
)
self._test_create_domain_fail(values,
exceptions.DomainIsSameAsAnEffectiveTLD)
self._test_create_domain_fail(
values, exceptions.DomainIsSameAsAnEffectiveTLD)
def test_find_domains(self):
context = self.get_admin_context()
# Ensure we have no domains to start with.
domains = self.central_service.find_domains(context)
domains = self.central_service.find_domains(self.admin_context)
self.assertEqual(len(domains), 0)
# Create a single domain (using default values)
self.create_domain()
# Ensure we can retrieve the newly created domain
domains = self.central_service.find_domains(context)
domains = self.central_service.find_domains(self.admin_context)
self.assertEqual(len(domains), 1)
self.assertEqual(domains[0]['name'], 'example.com.')
@ -581,27 +575,30 @@ class CentralServiceTest(CentralTestCase):
self.create_domain(name='example.net.')
# Ensure we can retrieve both domain
domains = self.central_service.find_domains(context)
domains = self.central_service.find_domains(self.admin_context)
self.assertEqual(len(domains), 2)
self.assertEqual(domains[0]['name'], 'example.com.')
self.assertEqual(domains[1]['name'], 'example.net.')
def test_find_domains_criteria(self):
context = self.get_admin_context()
# Create a domain
domain_name = '%d.example.com.' % random.randint(10, 1000)
expected_domain = self.create_domain(name=domain_name)
# Retrieve it, and ensure it's the same
criterion = {'name': domain_name}
domains = self.central_service.find_domains(context, criterion)
domains = self.central_service.find_domains(
self.admin_context, criterion)
self.assertEqual(domains[0]['id'], expected_domain['id'])
self.assertEqual(domains[0]['name'], expected_domain['name'])
self.assertEqual(domains[0]['email'], expected_domain['email'])
def test_find_domains_tenant_restrictions(self):
admin_context = self.get_admin_context()
admin_context.all_tenants = True
tenant_one_context = self.get_context(tenant=1)
tenant_two_context = self.get_context(tenant=2)
@ -610,65 +607,62 @@ class CentralServiceTest(CentralTestCase):
self.assertEqual(len(domains), 0)
# Create a single domain (using default values)
self.create_domain(context=tenant_one_context)
domain = self.create_domain(context=tenant_one_context)
# Ensure admins can retrieve the newly created domain
domains = self.central_service.find_domains(admin_context)
self.assertEqual(len(domains), 1)
self.assertEqual(domains[0]['name'], 'example.com.')
self.assertEqual(domains[0]['name'], domain['name'])
# Ensure tenant=1 can retrieve the newly created domain
domains = self.central_service.find_domains(tenant_one_context)
self.assertEqual(len(domains), 1)
self.assertEqual(domains[0]['name'], 'example.com.')
self.assertEqual(domains[0]['name'], domain['name'])
# Ensure tenant=2 can NOT retrieve the newly created domain
domains = self.central_service.find_domains(tenant_two_context)
self.assertEqual(len(domains), 0)
def test_get_domain(self):
context = self.get_admin_context()
# Create a domain
domain_name = '%d.example.com.' % random.randint(10, 1000)
expected_domain = self.create_domain(name=domain_name)
# Retrieve it, and ensure it's the same
domain = self.central_service.get_domain(context,
expected_domain['id'])
domain = self.central_service.get_domain(
self.admin_context, expected_domain['id'])
self.assertEqual(domain['id'], expected_domain['id'])
self.assertEqual(domain['name'], expected_domain['name'])
self.assertEqual(domain['email'], expected_domain['email'])
def test_get_domain_servers(self):
context = self.get_admin_context()
# Create a domain
domain = self.create_domain()
# Retrieve the servers list
servers = self.central_service.get_domain_servers(context,
domain['id'])
servers = self.central_service.get_domain_servers(
self.admin_context, domain['id'])
self.assertTrue(len(servers) > 0)
def test_find_domain(self):
context = self.get_admin_context()
# Create a domain
domain_name = '%d.example.com.' % random.randint(10, 1000)
expected_domain = self.create_domain(name=domain_name)
# Retrieve it, and ensure it's the same
criterion = {'name': domain_name}
domain = self.central_service.find_domain(context, criterion)
domain = self.central_service.find_domain(
self.admin_context, criterion)
self.assertEqual(domain['id'], expected_domain['id'])
self.assertEqual(domain['name'], expected_domain['name'])
self.assertEqual(domain['email'], expected_domain['email'])
self.assertIn('status', domain)
def test_update_domain(self):
context = self.get_admin_context()
# Create a domain
expected_domain = self.create_domain()
@ -677,12 +671,13 @@ class CentralServiceTest(CentralTestCase):
# Update the domain
values = dict(email='new@example.com')
self.central_service.update_domain(context, expected_domain['id'],
values=values)
self.central_service.update_domain(
self.admin_context, expected_domain['id'], values=values)
# Fetch the domain again
domain = self.central_service.get_domain(context,
expected_domain['id'])
domain = self.central_service.get_domain(
self.admin_context, expected_domain['id'])
# Ensure the domain was updated correctly
self.assertTrue(domain['serial'] > expected_domain['serial'])
@ -706,40 +701,36 @@ class CentralServiceTest(CentralTestCase):
self.assertEqual(payload['tenant_id'], domain['tenant_id'])
def test_update_domain_without_incrementing_serial(self):
context = self.get_admin_context()
# Create a domain
expected_domain = self.create_domain()
# Update the domain
values = dict(email='new@example.com')
self.central_service.update_domain(context, expected_domain['id'],
values=values,
increment_serial=False)
self.central_service.update_domain(
self.admin_context, expected_domain['id'], values=values,
increment_serial=False)
# Fetch the domain again
domain = self.central_service.get_domain(context,
expected_domain['id'])
domain = self.central_service.get_domain(
self.admin_context, expected_domain['id'])
# Ensure the domain was updated correctly
self.assertEqual(domain['serial'], expected_domain['serial'])
self.assertEqual(domain['email'], 'new@example.com')
def test_update_domain_name_fail(self):
context = self.get_admin_context()
# Create a domain
expected_domain = self.create_domain()
# Update the domain
with testtools.ExpectedException(exceptions.BadRequest):
values = dict(name='renamed-domain.com.')
self.central_service.update_domain(context, expected_domain['id'],
values=values)
self.central_service.update_domain(
self.admin_context, expected_domain['id'], values=values)
def test_delete_domain(self):
context = self.get_admin_context()
# Create a domain
domain = self.create_domain()
@ -747,11 +738,11 @@ class CentralServiceTest(CentralTestCase):
self.reset_notifications()
# Delete the domain
self.central_service.delete_domain(context, domain['id'])
self.central_service.delete_domain(self.admin_context, domain['id'])
# Fetch the domain again, ensuring an exception is raised
with testtools.ExpectedException(exceptions.DomainNotFound):
self.central_service.get_domain(context, domain['id'])
self.central_service.get_domain(self.admin_context, domain['id'])
# Ensure we sent exactly 1 notification
notifications = self.get_notifications()
@ -771,8 +762,6 @@ class CentralServiceTest(CentralTestCase):
self.assertEqual(payload['tenant_id'], domain['tenant_id'])
def test_delete_parent_domain(self):
context = self.get_admin_context()
# Create the Parent Domain using fixture 0
parent_domain = self.create_domain(fixture=0)
@ -781,7 +770,8 @@ class CentralServiceTest(CentralTestCase):
# Attempt to delete the parent domain
with testtools.ExpectedException(exceptions.DomainHasSubdomain):
self.central_service.delete_domain(context, parent_domain['id'])
self.central_service.delete_domain(
self.admin_context, parent_domain['id'])
def test_count_domains(self):
# in the beginning, there should be nothing
@ -805,24 +795,22 @@ class CentralServiceTest(CentralTestCase):
self.central_service.count_domains(self.get_context())
def test_touch_domain(self):
context = self.get_admin_context()
# Create a domain
expected_domain = self.create_domain()
# Touch the domain
self.central_service.touch_domain(context, expected_domain['id'])
self.central_service.touch_domain(
self.admin_context, expected_domain['id'])
# Fetch the domain again
domain = self.central_service.get_domain(context,
expected_domain['id'])
domain = self.central_service.get_domain(
self.admin_context, expected_domain['id'])
# Ensure the serial was incremented
self.assertTrue(domain['serial'] > expected_domain['serial'])
# Record Tests
def test_create_record(self):
context = self.get_admin_context()
domain = self.create_domain()
values = dict(
@ -832,8 +820,8 @@ class CentralServiceTest(CentralTestCase):
)
# Create a record
record = self.central_service.create_record(context, domain['id'],
values=values)
record = self.central_service.create_record(
self.admin_context, domain['id'], values=values)
# Ensure all values have been set correctly
self.assertIsNotNone(record['id'])
@ -854,7 +842,6 @@ class CentralServiceTest(CentralTestCase):
self.create_record(domain)
def test_create_record_without_incrementing_serial(self):
context = self.get_admin_context()
domain = self.create_domain()
values = dict(
@ -864,9 +851,9 @@ class CentralServiceTest(CentralTestCase):
)
# Create a record
record = self.central_service.create_record(context, domain['id'],
values=values,
increment_serial=False)
record = self.central_service.create_record(
self.admin_context, domain['id'], values=values,
increment_serial=False)
# Ensure all values have been set correctly
self.assertIsNotNone(record['id'])
@ -876,11 +863,12 @@ class CentralServiceTest(CentralTestCase):
self.assertEqual(record['data'], values['data'])
# Ensure the domains serial number was not updated
updated_domain = self.central_service.get_domain(context, domain['id'])
updated_domain = self.central_service.get_domain(
self.admin_context, domain['id'])
self.assertEqual(domain['serial'], updated_domain['serial'])
def test_create_cname_record_at_apex(self):
context = self.get_admin_context()
domain = self.create_domain()
values = dict(
@ -891,11 +879,10 @@ class CentralServiceTest(CentralTestCase):
# Attempt to create a CNAME record at the apex
with testtools.ExpectedException(exceptions.InvalidRecordLocation):
self.central_service.create_record(context, domain['id'],
values=values)
self.central_service.create_record(
self.admin_context, domain['id'], values=values)
def test_create_cname_record_above_an_a_record(self):
context = self.get_admin_context()
domain = self.create_domain()
values = dict(
@ -904,8 +891,8 @@ class CentralServiceTest(CentralTestCase):
data='127.0.0.1'
)
self.central_service.create_record(context, domain['id'],
values=values)
self.central_service.create_record(
self.admin_context, domain['id'], values=values)
# Create a CNAME record alongside an A record
values = dict(
@ -914,13 +901,12 @@ class CentralServiceTest(CentralTestCase):
data='example.org.'
)
record = self.central_service.create_record(context, domain['id'],
values=values)
record = self.central_service.create_record(
self.admin_context, domain['id'], values=values)
self.assertIn('id', record)
def test_create_cname_record_below_an_a_record(self):
context = self.get_admin_context()
domain = self.create_domain()
values = dict(
@ -929,8 +915,8 @@ class CentralServiceTest(CentralTestCase):
data='127.0.0.1'
)
self.central_service.create_record(context, domain['id'],
values=values)
self.central_service.create_record(
self.admin_context, domain['id'], values=values)
# Create a CNAME record alongside an A record
values = dict(
@ -939,13 +925,12 @@ class CentralServiceTest(CentralTestCase):
data='example.org.'
)
record = self.central_service.create_record(context, domain['id'],
values=values)
record = self.central_service.create_record(
self.admin_context, domain['id'], values=values)
self.assertIn('id', record)
def test_create_cname_record_alongside_an_a_record(self):
context = self.get_admin_context()
domain = self.create_domain()
values = dict(
@ -954,8 +939,8 @@ class CentralServiceTest(CentralTestCase):
data='127.0.0.1'
)
self.central_service.create_record(context, domain['id'],
values=values)
self.central_service.create_record(
self.admin_context, domain['id'], values=values)
# Attempt to create a CNAME record alongside an A record
with testtools.ExpectedException(exceptions.InvalidRecordLocation):
@ -965,11 +950,10 @@ class CentralServiceTest(CentralTestCase):
data='example.org.'
)
self.central_service.create_record(context, domain['id'],
values=values)
self.central_service.create_record(
self.admin_context, domain['id'], values=values)
def test_create_an_a_record_alongside_a_cname_record(self):
context = self.get_admin_context()
domain = self.create_domain()
values = dict(
@ -978,8 +962,8 @@ class CentralServiceTest(CentralTestCase):
data='example.org.'
)
self.central_service.create_record(context, domain['id'],
values=values)
self.central_service.create_record(
self.admin_context, domain['id'], values=values)
# Attempt to create a CNAME record alongside an A record
with testtools.ExpectedException(exceptions.InvalidRecordLocation):
@ -989,11 +973,10 @@ class CentralServiceTest(CentralTestCase):
data='127.0.0.1'
)
self.central_service.create_record(context, domain['id'],
values=values)
self.central_service.create_record(
self.admin_context, domain['id'], values=values)
def test_create_duplicate_ptr_record(self):
context = self.get_admin_context()
domain = self.create_domain(values={'name': '2.0.192.in-addr.arpa.'})
values = dict(
@ -1002,8 +985,8 @@ class CentralServiceTest(CentralTestCase):
data='www.example.org.'
)
self.central_service.create_record(context, domain['id'],
values=values)
self.central_service.create_record(
self.admin_context, domain['id'], values=values)
# Attempt to create a second PTR with the same name.
with testtools.ExpectedException(exceptions.DuplicateRecord):
@ -1013,22 +996,25 @@ class CentralServiceTest(CentralTestCase):
data='www.example.com.'
)
self.central_service.create_record(context, domain['id'],
values=values)
self.central_service.create_record(
self.admin_context, domain['id'], values=values)
def test_find_records(self):
context = self.get_admin_context()
domain = self.create_domain()
# Ensure we have no records to start with.
records = self.central_service.find_records(context, domain['id'])
records = self.central_service.find_records(
self.admin_context, domain['id'])
self.assertEqual(len(records), 0)
# Create a single record (using default values)
self.create_record(domain)
# Ensure we can retrieve the newly created record
records = self.central_service.find_records(context, domain['id'])
records = self.central_service.find_records(
self.admin_context, domain['id'])
self.assertEqual(len(records), 1)
self.assertEqual(records[0]['name'], 'www.%s' % domain['name'])
@ -1036,13 +1022,14 @@ class CentralServiceTest(CentralTestCase):
self.create_record(domain, name='mail.%s' % domain['name'])
# Ensure we can retrieve both records
records = self.central_service.find_records(context, domain['id'])
records = self.central_service.find_records(
self.admin_context, domain['id'])
self.assertEqual(len(records), 2)
self.assertEqual(records[0]['name'], 'www.%s' % domain['name'])
self.assertEqual(records[1]['name'], 'mail.%s' % domain['name'])
def test_get_record(self):
context = self.get_admin_context()
domain = self.create_domain()
# Create a record
@ -1050,14 +1037,14 @@ class CentralServiceTest(CentralTestCase):
expected_record = self.create_record(domain, name=record_name)
# Retrieve it, and ensure it's the same
record = self.central_service.get_record(context, domain['id'],
expected_record['id'])
record = self.central_service.get_record(
self.admin_context, domain['id'], expected_record['id'])
self.assertEqual(record['id'], expected_record['id'])
self.assertEqual(record['name'], expected_record['name'])
self.assertIn('status', record)
def test_find_record(self):
context = self.get_admin_context()
domain = self.create_domain()
# Create a record
@ -1066,14 +1053,15 @@ class CentralServiceTest(CentralTestCase):
# Retrieve it, and ensure it's the same
criterion = {'name': record_name}
record = self.central_service.find_record(context, domain['id'],
criterion)
record = self.central_service.find_record(
self.admin_context, domain['id'], criterion)
self.assertEqual(record['id'], expected_record['id'])
self.assertEqual(record['name'], expected_record['name'])
self.assertIn('status', record)
def test_get_record_incorrect_domain_id(self):
context = self.get_admin_context()
domain = self.create_domain()
other_domain = self.create_domain(fixture=1)
@ -1083,11 +1071,10 @@ class CentralServiceTest(CentralTestCase):
# Ensure we get a 404 if we use the incorrect domain_id
with testtools.ExpectedException(exceptions.RecordNotFound):
self.central_service.get_record(context, other_domain['id'],
expected_record['id'])
self.central_service.get_record(
self.admin_context, other_domain['id'], expected_record['id'])
def test_update_record(self):
context = self.get_admin_context()
domain = self.create_domain()
# Create a record
@ -1095,48 +1082,48 @@ class CentralServiceTest(CentralTestCase):
# Update the record
values = dict(data='127.0.0.2')
self.central_service.update_record(context, domain['id'],
expected_record['id'],
values=values)
self.central_service.update_record(
self.admin_context, domain['id'], expected_record['id'],
values=values)
# Fetch the record again
record = self.central_service.get_record(context, domain['id'],
expected_record['id'])
record = self.central_service.get_record(
self.admin_context, domain['id'], expected_record['id'])
# Ensure the record was updated correctly
self.assertEqual(record['data'], '127.0.0.2')
def test_update_record_without_incrementing_serial(self):
context = self.get_admin_context()
domain = self.create_domain()
# Create a record
expected_record = self.create_record(domain)
# Fetch the domain so we have the latest serial number
domain_before = self.central_service.get_domain(context, domain['id'])
domain_before = self.central_service.get_domain(
self.admin_context, domain['id'])
# Update the record
values = dict(data='127.0.0.2')
self.central_service.update_record(context,
domain['id'],
expected_record['id'],
values,
increment_serial=False)
self.central_service.update_record(
self.admin_context, domain['id'], expected_record['id'],
values, increment_serial=False)
# Fetch the record again
record = self.central_service.get_record(context, domain['id'],
expected_record['id'])
record = self.central_service.get_record(
self.admin_context, domain['id'], expected_record['id'])
# Ensure the record was updated correctly
self.assertEqual(record['data'], '127.0.0.2')
# Ensure the domains serial number was not updated
domain_after = self.central_service.get_domain(context, domain['id'])
domain_after = self.central_service.get_domain(
self.admin_context, domain['id'])
self.assertEqual(domain_before['serial'], domain_after['serial'])
def test_update_record_incorrect_domain_id(self):
context = self.get_admin_context()
domain = self.create_domain()
other_domain = self.create_domain(fixture=1)
@ -1148,12 +1135,11 @@ class CentralServiceTest(CentralTestCase):
# Ensure we get a 404 if we use the incorrect domain_id
with testtools.ExpectedException(exceptions.RecordNotFound):
self.central_service.update_record(context, other_domain['id'],
expected_record['id'],
values=values)
self.central_service.update_record(
self.admin_context, other_domain['id'], expected_record['id'],
values=values)
def test_update_record_duplicate_ptr(self):
context = self.get_admin_context()
domain = self.create_domain(values={'name': '2.0.192.in-addr.arpa.'})
values = dict(
@ -1162,8 +1148,8 @@ class CentralServiceTest(CentralTestCase):
data='www.example.org.'
)
self.central_service.create_record(context, domain['id'],
values=values)
self.central_service.create_record(
self.admin_context, domain['id'], values=values)
values = dict(
name='2.%s' % domain['name'],
@ -1171,8 +1157,8 @@ class CentralServiceTest(CentralTestCase):
data='www.example.org.'
)
record = self.central_service.create_record(context, domain['id'],
values=values)
record = self.central_service.create_record(
self.admin_context, domain['id'], values=values)
# Attempt to create a second PTR with the same name.
with testtools.ExpectedException(exceptions.DuplicateRecord):
@ -1180,12 +1166,10 @@ class CentralServiceTest(CentralTestCase):
name='1.%s' % domain['name']
)
self.central_service.update_record(context, domain['id'],
record['id'],
values=values)
self.central_service.update_record(
self.admin_context, domain['id'], record['id'], values=values)
def test_update_record_cname_data(self):
context = self.get_admin_context()
domain = self.create_domain()
# Create a record
@ -1194,19 +1178,18 @@ class CentralServiceTest(CentralTestCase):
# Update the record
values = dict(data='example.com.')
self.central_service.update_record(context, domain['id'],
self.central_service.update_record(self.admin_context, domain['id'],
expected_record['id'],
values=values)
# Fetch the record again
record = self.central_service.get_record(context, domain['id'],
expected_record['id'])
record = self.central_service.get_record(
self.admin_context, domain['id'], expected_record['id'])
# Ensure the record was updated correctly
self.assertEqual(record['data'], 'example.com.')
def test_update_record_ptr_data(self):
context = self.get_admin_context()
domain = self.create_domain(name='2.0.192.in-addr.arpa.')
# Create a record
@ -1218,57 +1201,60 @@ class CentralServiceTest(CentralTestCase):
# Update the record
values = dict(data='example.com.')
self.central_service.update_record(context, domain['id'],
self.central_service.update_record(self.admin_context, domain['id'],
expected_record['id'],
values=values)
# Fetch the record again
record = self.central_service.get_record(context, domain['id'],
record = self.central_service.get_record(self.admin_context,
domain['id'],
expected_record['id'])
# Ensure the record was updated correctly
self.assertEqual(record['data'], 'example.com.')
def test_delete_record(self):
context = self.get_admin_context()
domain = self.create_domain()
# Create a record
record = self.create_record(domain)
# Delete the record
self.central_service.delete_record(context, domain['id'], record['id'])
self.central_service.delete_record(self.admin_context, domain['id'],
record['id'])
# Fetch the record again, ensuring an exception is raised
with testtools.ExpectedException(exceptions.RecordNotFound):
self.central_service.get_record(context, domain['id'],
self.central_service.get_record(self.admin_context, domain['id'],
record['id'])
def test_delete_record_without_incrementing_serial(self):
context = self.get_admin_context()
domain = self.create_domain()
# Create a record
record = self.create_record(domain)
# Fetch the domain so we have the latest serial number
domain_before = self.central_service.get_domain(context, domain['id'])
domain_before = self.central_service.get_domain(
self.admin_context, domain['id'])
# Delete the record
self.central_service.delete_record(context, domain['id'], record['id'],
increment_serial=False)
self.central_service.delete_record(
self.admin_context, domain['id'], record['id'],
increment_serial=False)
# Fetch the record again, ensuring an exception is raised
with testtools.ExpectedException(exceptions.RecordNotFound):
self.central_service.get_record(context, domain['id'],
record['id'])
self.central_service.get_record(
self.admin_context, domain['id'], record['id'])
# Ensure the domains serial number was not updated
domain_after = self.central_service.get_domain(context, domain['id'])
domain_after = self.central_service.get_domain(
self.admin_context, domain['id'])
self.assertEqual(domain_before['serial'], domain_after['serial'])
def test_delete_record_incorrect_domain_id(self):
context = self.get_admin_context()
domain = self.create_domain()
other_domain = self.create_domain(fixture=1)
@ -1277,8 +1263,8 @@ class CentralServiceTest(CentralTestCase):
# Ensure we get a 404 if we use the incorrect domain_id
with testtools.ExpectedException(exceptions.RecordNotFound):
self.central_service.delete_record(context, other_domain['id'],
record['id'])
self.central_service.delete_record(
self.admin_context, other_domain['id'], record['id'])
def test_count_records(self):
# in the beginning, there should be nothing

View File

@ -18,26 +18,6 @@ from designate import context
class TestDesignateContext(TestCase):
def test_sudo(self):
# Set the policy to accept the authz
self.policy({'use_sudo': '@'})
ctxt = context.DesignateContext(tenant='original')
ctxt.sudo('effective')
self.assertEqual('effective', ctxt.tenant_id)
self.assertEqual('original', ctxt.original_tenant_id)
def test_sudo_fail(self):
# Set the policy to deny the authz
self.policy({'use_sudo': '!'})
ctxt = context.DesignateContext(tenant='original')
ctxt.sudo('effective')
self.assertEqual('original', ctxt.tenant_id)
self.assertEqual('original', ctxt.original_tenant_id)
def test_deepcopy(self):
orig = context.DesignateContext(user='12345', tenant='54321')
copy = orig.deepcopy()

View File

@ -27,8 +27,10 @@ class StorageQuotaTest(tests.TestCase):
self.quota = quota.get_quota()
def test_set_quota_create(self):
quota = self.quota.set_quota(self.admin_context, 'tenant_id',
'domains', 1500)
context = self.get_admin_context()
context.all_tenants = True
quota = self.quota.set_quota(context, 'tenant_id', 'domains', 1500)
self.assertEqual(quota, {'domains': 1500})
@ -39,19 +41,21 @@ class StorageQuotaTest(tests.TestCase):
'resource': 'domains'
}
quota = self.quota.storage_api.find_quota(self.admin_context,
criterion)
quota = self.quota.storage_api.find_quota(context, criterion)
self.assertEqual(quota['tenant_id'], 'tenant_id')
self.assertEqual(quota['resource'], 'domains')
self.assertEqual(quota['hard_limit'], 1500)
def test_set_quota_update(self):
context = self.get_admin_context()
context.all_tenants = True
# First up, Create the quota
self.quota.set_quota(self.admin_context, 'tenant_id', 'domains', 1500)
self.quota.set_quota(context, 'tenant_id', 'domains', 1500)
# Next, update the quota
self.quota.set_quota(self.admin_context, 'tenant_id', 'domains', 1234)
self.quota.set_quota(context, 'tenant_id', 'domains', 1234)
# Drop into the storage layer directly to ensure the quota was updated
# sucessfully
@ -60,23 +64,24 @@ class StorageQuotaTest(tests.TestCase):
'resource': 'domains'
}
quota = self.quota.storage_api.find_quota(self.admin_context,
criterion)
quota = self.quota.storage_api.find_quota(context, criterion)
self.assertEqual(quota['tenant_id'], 'tenant_id')
self.assertEqual(quota['resource'], 'domains')
self.assertEqual(quota['hard_limit'], 1234)
def test_reset_quotas(self):
context = self.get_admin_context()
context.all_tenants = True
# First up, Create a domains quota
self.quota.set_quota(self.admin_context, 'tenant_id', 'domains', 1500)
self.quota.set_quota(context, 'tenant_id', 'domains', 1500)
# Then, Create a domain_records quota
self.quota.set_quota(self.admin_context, 'tenant_id', 'domain_records',
800)
self.quota.set_quota(context, 'tenant_id', 'domain_records', 800)
# Now, Reset the tenants quota
self.quota.reset_quotas(self.admin_context, 'tenant_id')
self.quota.reset_quotas(context, 'tenant_id')
# Drop into the storage layer directly to ensure the tenant has no
# specific quotas registed.
@ -84,7 +89,5 @@ class StorageQuotaTest(tests.TestCase):
'tenant_id': 'tenant_id'
}
quotas = self.quota.storage_api.find_quotas(self.admin_context,
criterion)
quotas = self.quota.storage_api.find_quotas(context, criterion)
self.assertEqual(0, len(quotas))

View File

@ -22,37 +22,55 @@ LOG = logging.getLogger(__name__)
class StorageTestCase(object):
def create_quota(self, fixture=0, values={}):
def create_quota(self, fixture=0, values={}, context=None):
if not context:
context = self.admin_context
fixture = self.get_quota_fixture(fixture, values)
return fixture, self.storage.create_quota(self.admin_context, fixture)
def create_server(self, fixture=0, values={}):
if 'tenant_id' not in fixture:
fixture['tenant_id'] = context.tenant_id
return fixture, self.storage.create_quota(context, fixture)
def create_server(self, fixture=0, values={}, context=None):
if not context:
context = self.admin_context
fixture = self.get_server_fixture(fixture, values)
return fixture, self.storage.create_server(self.admin_context, fixture)
return fixture, self.storage.create_server(context, fixture)
def create_tsigkey(self, fixture=0, values={}, context=None):
if not context:
context = self.admin_context
def create_tsigkey(self, fixture=0, values={}):
fixture = self.get_tsigkey_fixture(fixture, values)
return fixture, self.storage.create_tsigkey(self.admin_context,
fixture)
return fixture, self.storage.create_tsigkey(context, fixture)
def create_domain(self, fixture=0, values={}, context=None):
if not context:
context = self.admin_context
def create_domain(self, fixture=0, values={}):
fixture = self.get_domain_fixture(fixture, values)
if 'tenant_id' not in values:
fixture['tenant_id'] = self.admin_context.tenant_id
if 'tenant_id' not in fixture:
fixture['tenant_id'] = context.tenant_id
return fixture, self.storage.create_domain(self.admin_context,
fixture)
return fixture, self.storage.create_domain(context, fixture)
def create_record(self, domain, fixture=0, values={}, context=None):
if not context:
context = self.admin_context
def create_record(self, domain, fixture=0, values={}):
fixture = self.get_record_fixture(domain['name'], fixture, values)
return fixture, self.storage.create_record(self.admin_context,
return fixture, self.storage.create_record(context,
domain['id'],
fixture)
# Quota Tests
def test_create_quota(self):
values = self.get_quota_fixture()
values['tenant_id'] = self.admin_context.tenant_id
result = self.storage.create_quota(self.admin_context, values=values)
@ -60,7 +78,7 @@ class StorageTestCase(object):
self.assertIsNotNone(result['created_at'])
self.assertIsNone(result['updated_at'])
self.assertEqual(result['tenant_id'], values['tenant_id'])
self.assertEqual(result['tenant_id'], self.admin_context.tenant_id)
self.assertEqual(result['resource'], values['resource'])
self.assertEqual(result['hard_limit'], values['hard_limit'])
@ -187,15 +205,16 @@ class StorageTestCase(object):
self.assertEqual(updated['hard_limit'], fixture['hard_limit'])
def test_update_quota_duplicate(self):
# Create two quotas
self.create_quota(fixture=0)
_, quota = self.create_quota(fixture=1)
context = self.get_admin_context()
context.all_tenants = True
values = self.quota_fixtures[0]
# Create two quotas
self.create_quota(fixture=0, values={'tenant_id': '1'})
_, quota = self.create_quota(fixture=0, values={'tenant_id': '2'})
with testtools.ExpectedException(exceptions.DuplicateQuota):
self.storage.update_quota(self.admin_context, quota['id'],
values)
self.storage.update_quota(context, quota['id'],
values={'tenant_id': '1'})
def test_update_quota_missing(self):
with testtools.ExpectedException(exceptions.QuotaNotFound):
@ -460,16 +479,19 @@ class StorageTestCase(object):
# Tenant Tests
def test_find_tenants(self):
context = self.get_admin_context()
context.all_tenants = True
# create 3 domains in 2 tenants
self.create_domain(fixture=0, values={'tenant_id': 'One'})
_, domain = self.create_domain(fixture=1, values={'tenant_id': 'One'})
self.create_domain(fixture=2, values={'tenant_id': 'Two'})
# Delete one of the domains.
self.storage.delete_domain(self.admin_context, domain['id'])
self.storage.delete_domain(context, domain['id'])
# Ensure we get accurate results
result = self.storage.find_tenants(self.admin_context)
result = self.storage.find_tenants(context)
expected = [{
'id': 'One',
@ -482,15 +504,18 @@ class StorageTestCase(object):
self.assertEqual(result, expected)
def test_get_tenant(self):
context = self.get_admin_context()
context.all_tenants = True
# create 2 domains in a tenant
_, domain_1 = self.create_domain(fixture=0, values={'tenant_id': 1})
_, domain_2 = self.create_domain(fixture=1, values={'tenant_id': 1})
_, domain_3 = self.create_domain(fixture=2, values={'tenant_id': 1})
# Delete one of the domains.
self.storage.delete_domain(self.admin_context, domain_3['id'])
self.storage.delete_domain(context, domain_3['id'])
result = self.storage.get_tenant(self.admin_context, 1)
result = self.storage.get_tenant(context, 1)
self.assertEqual(result['id'], 1)
self.assertEqual(result['domain_count'], 2)
@ -498,8 +523,11 @@ class StorageTestCase(object):
[domain_1['name'], domain_2['name']])
def test_count_tenants(self):
context = self.get_admin_context()
context.all_tenants = True
# in the beginning, there should be nothing
tenants = self.storage.count_tenants(self.admin_context)
tenants = self.storage.count_tenants(context)
self.assertEqual(tenants, 0)
# create 2 domains with 2 tenants
@ -508,9 +536,9 @@ class StorageTestCase(object):
_, domain = self.create_domain(fixture=2, values={'tenant_id': 2})
# Delete one of the domains.
self.storage.delete_domain(self.admin_context, domain['id'])
self.storage.delete_domain(context, domain['id'])
tenants = self.storage.count_tenants(self.admin_context)
tenants = self.storage.count_tenants(context)
self.assertEqual(tenants, 2)
# Domain Tests
@ -586,6 +614,38 @@ class StorageTestCase(object):
self.assertEqual(results[0]['email'], domain_two['email'])
self.assertIn('status', domain_two)
def test_find_domains_all_tenants(self):
# Create two contexts with different tenant_id's
one_context = self.get_admin_context()
one_context.tenant = 1
two_context = self.get_admin_context()
two_context.tenant = 2
# Create normal and all_tenants context objects
nm_context = self.get_admin_context()
at_context = self.get_admin_context()
at_context.all_tenants = True
# Create two domains in different tenants
self.create_domain(fixture=0, context=one_context)
self.create_domain(fixture=1, context=two_context)
# Ensure the all_tenants context see's two domains
results = self.storage.find_domains(at_context)
self.assertEqual(len(results), 2)
# Ensure the normal context see's no domains
results = self.storage.find_domains(nm_context)
self.assertEqual(len(results), 0)
# Ensure the tenant 1 context see's 1 domain
results = self.storage.find_domains(one_context)
self.assertEqual(len(results), 1)
# Ensure the tenant 2 context see's 1 domain
results = self.storage.find_domains(two_context)
self.assertEqual(len(results), 1)
def test_get_domain(self):
# Create a domain
fixture, expected = self.create_domain()
@ -604,7 +664,7 @@ class StorageTestCase(object):
context = self.get_admin_context()
context.show_deleted = True
_, domain = self.create_domain()
_, domain = self.create_domain(context=context)
self.storage.delete_domain(context, domain['id'])
self.storage.get_domain(context, domain['id'])
@ -801,6 +861,41 @@ class StorageTestCase(object):
self.assertEqual(len(results), 1)
def test_find_records_all_tenants(self):
# Create two contexts with different tenant_id's
one_context = self.get_admin_context()
one_context.tenant = 1
two_context = self.get_admin_context()
two_context.tenant = 2
# Create normal and all_tenants context objects
nm_context = self.get_admin_context()
at_context = self.get_admin_context()
at_context.all_tenants = True
# Create two domains in different tenants, and 1 record in each
_, domain_one = self.create_domain(fixture=0, context=one_context)
self.create_record(domain_one, fixture=0, context=one_context)
_, domain_two = self.create_domain(fixture=1, context=two_context)
self.create_record(domain_two, fixture=0, context=two_context)
# Ensure the all_tenants context see's two records
results = self.storage.find_records(at_context)
self.assertEqual(len(results), 2)
# Ensure the normal context see's no records
results = self.storage.find_records(nm_context)
self.assertEqual(len(results), 0)
# Ensure the tenant 1 context see's 1 record
results = self.storage.find_records(one_context)
self.assertEqual(len(results), 1)
# Ensure the tenant 2 context see's 1 record
results = self.storage.find_records(two_context)
self.assertEqual(len(results), 1)
def test_get_record(self):
_, domain = self.create_domain()

View File

@ -1,6 +1,6 @@
{
"admin": "role:admin or is_admin:True",
"owner": "tenant_id:%(tenant_id)s or tenant_id:%(effective_tenant_id)s",
"owner": "tenant_id:%(tenant_id)s",
"admin_or_owner": "rule:admin or rule:owner",
"default": "rule:admin_or_owner",