Provide support for "All Tenants" access
This moves some of the per-tenant result filtering into the storage layer, a necessary pre-requisite to supporting both the v1 and v2 APIs for record access. Again, Additionally, This fixes a number of issues we had with our existing tests, where the handling of tenant_id's was inconsistent. Change-Id: I1a20a3c81edbf74a7d6256118d206e509c6da65c
This commit is contained in:
parent
0d23371c87
commit
aa780afb12
@ -23,7 +23,7 @@ from designate.context import DesignateContext
|
||||
from designate.openstack.common import jsonutils as json
|
||||
from designate.openstack.common import local
|
||||
from designate.openstack.common import log as logging
|
||||
from designate.openstack.common import uuidutils
|
||||
from designate.openstack.common import strutils
|
||||
from designate.openstack.common.rpc import common as rpc_common
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -102,10 +102,10 @@ class KeystoneContextMiddleware(ContextMiddleware):
|
||||
|
||||
try:
|
||||
if headers['X-Identity-Status'] is 'Invalid':
|
||||
#TODO(graham) fix the return to use non-flask resources
|
||||
# TODO(graham) fix the return to use non-flask resources
|
||||
return flask.Response(status=401)
|
||||
except KeyError:
|
||||
#If the key is valid, Keystone does not include this header at all
|
||||
# If the key is valid, Keystone does not include this header at all
|
||||
pass
|
||||
|
||||
if headers.get('X-Service-Catalog'):
|
||||
@ -124,13 +124,6 @@ class KeystoneContextMiddleware(ContextMiddleware):
|
||||
# Store the context where oslo-log exepcts to find it.
|
||||
local.store.context = context
|
||||
|
||||
# Attempt to sudo, if requested.
|
||||
sudo_tenant_id = headers.get('X-Designate-Sudo-Tenant-ID', None)
|
||||
|
||||
if sudo_tenant_id and (uuidutils.is_uuid_like(sudo_tenant_id)
|
||||
or sudo_tenant_id.isdigit()):
|
||||
context.sudo(sudo_tenant_id)
|
||||
|
||||
# Attach the context to the request environment
|
||||
request.environ['context'] = context
|
||||
|
||||
@ -153,6 +146,34 @@ class NoAuthContextMiddleware(ContextMiddleware):
|
||||
request.environ['context'] = context
|
||||
|
||||
|
||||
class TestContextMiddleware(ContextMiddleware):
|
||||
def __init__(self, application, tenant_id=None, user_id=None):
|
||||
super(TestContextMiddleware, self).__init__(application)
|
||||
|
||||
LOG.critical('Starting designate testcontext middleware')
|
||||
LOG.critical('**** DO NOT USE IN PRODUCTION ****')
|
||||
|
||||
self.default_tenant_id = tenant_id
|
||||
self.default_user_id = user_id
|
||||
|
||||
def process_request(self, request):
|
||||
headers = request.headers
|
||||
|
||||
all_tenants = strutils.bool_from_string(
|
||||
headers.get('X-Test-All-Tenants', 'False'))
|
||||
|
||||
context = DesignateContext(
|
||||
user=headers.get('X-Test-User-ID', self.default_user_id),
|
||||
tenant=headers.get('X-Test-Tenant-ID', self.default_tenant_id),
|
||||
all_tenants=all_tenants)
|
||||
|
||||
# Store the context where oslo-log exepcts to find it.
|
||||
local.store.context = context
|
||||
|
||||
# Attach the context to the request environment
|
||||
request.environ['context'] = context
|
||||
|
||||
|
||||
class FaultWrapperMiddleware(wsgi.Middleware):
|
||||
def __init__(self, application):
|
||||
super(FaultWrapperMiddleware, self).__init__(application)
|
||||
|
@ -183,6 +183,9 @@ class Service(rpc_service.Service):
|
||||
return False
|
||||
|
||||
def _is_subdomain(self, context, domain_name):
|
||||
context = context.elevated()
|
||||
context.all_tenants = True
|
||||
|
||||
# Break the name up into it's component labels
|
||||
labels = domain_name.split(".")
|
||||
|
||||
@ -403,7 +406,10 @@ class Service(rpc_service.Service):
|
||||
# Domain Methods
|
||||
def create_domain(self, context, values):
|
||||
# TODO(kiall): Refactor this method into *MUCH* smaller chunks.
|
||||
values['tenant_id'] = context.tenant_id
|
||||
|
||||
# Default to creating in the current users tenant
|
||||
if 'tenant_id' not in values:
|
||||
values['tenant_id'] = context.tenant_id
|
||||
|
||||
target = {
|
||||
'tenant_id': values['tenant_id'],
|
||||
@ -483,21 +489,12 @@ class Service(rpc_service.Service):
|
||||
target = {'tenant_id': context.tenant_id}
|
||||
policy.check('find_domains', context, target)
|
||||
|
||||
if criterion is None:
|
||||
criterion = {}
|
||||
|
||||
if not context.is_admin:
|
||||
criterion['tenant_id'] = context.tenant_id
|
||||
|
||||
return self.storage_api.find_domains(context, criterion)
|
||||
|
||||
def find_domain(self, context, criterion):
|
||||
target = {'tenant_id': context.tenant_id}
|
||||
policy.check('find_domain', context, target)
|
||||
|
||||
if not context.is_admin:
|
||||
criterion['tenant_id'] = context.tenant_id
|
||||
|
||||
return self.storage_api.find_domain(context, criterion)
|
||||
|
||||
def update_domain(self, context, domain_id, values, increment_serial=True):
|
||||
|
@ -16,7 +16,6 @@
|
||||
import itertools
|
||||
from designate.openstack.common import context
|
||||
from designate.openstack.common import log as logging
|
||||
from designate import policy
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -24,7 +23,7 @@ LOG = logging.getLogger(__name__)
|
||||
class DesignateContext(context.RequestContext):
|
||||
def __init__(self, auth_token=None, user=None, tenant=None, is_admin=False,
|
||||
read_only=False, show_deleted=False, request_id=None,
|
||||
original_tenant_id=None, roles=[], service_catalog=None):
|
||||
roles=[], service_catalog=None, all_tenants=False):
|
||||
super(DesignateContext, self).__init__(
|
||||
auth_token=auth_token,
|
||||
user=user,
|
||||
@ -34,29 +33,9 @@ class DesignateContext(context.RequestContext):
|
||||
show_deleted=show_deleted,
|
||||
request_id=request_id)
|
||||
|
||||
self._original_tenant_id = original_tenant_id
|
||||
self.roles = roles
|
||||
self.service_catalog = service_catalog
|
||||
|
||||
def sudo(self, tenant_id, force=False):
|
||||
if force:
|
||||
allowed_sudo = True
|
||||
else:
|
||||
# We use exc=None here since the context is built early in the
|
||||
# request lifecycle, outside of our ordinary error handling.
|
||||
# For now, we silently ignore failed sudo requests.
|
||||
target = {'tenant_id': tenant_id}
|
||||
allowed_sudo = policy.check('use_sudo', self, target, exc=None)
|
||||
|
||||
if allowed_sudo:
|
||||
LOG.warn('Accepted sudo from user_id %s for tenant_id %s'
|
||||
% (self.user_id, tenant_id))
|
||||
self.original_tenant_id = self.tenant_id
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
else:
|
||||
LOG.warn('Rejected sudo from user_id %s for tenant_id %s'
|
||||
% (self.user_id, tenant_id))
|
||||
self.all_tenants = all_tenants
|
||||
|
||||
def deepcopy(self):
|
||||
d = self.to_dict()
|
||||
@ -73,9 +52,9 @@ class DesignateContext(context.RequestContext):
|
||||
d.update({
|
||||
'user_id': self.user_id,
|
||||
'tenant_id': self.tenant_id,
|
||||
'original_tenant_id': self.original_tenant_id,
|
||||
'roles': self.roles,
|
||||
'service_catalog': self.service_catalog
|
||||
'service_catalog': self.service_catalog,
|
||||
'all_tenants': self.all_tenants,
|
||||
})
|
||||
|
||||
return d
|
||||
@ -113,17 +92,6 @@ class DesignateContext(context.RequestContext):
|
||||
def tenant_id(self, value):
|
||||
self.tenant = value
|
||||
|
||||
@property
|
||||
def original_tenant_id(self):
|
||||
if self._original_tenant_id:
|
||||
return self._original_tenant_id
|
||||
else:
|
||||
return self.tenant
|
||||
|
||||
@original_tenant_id.setter
|
||||
def original_tenant_id(self, value):
|
||||
self._original_tenant_id = value
|
||||
|
||||
@classmethod
|
||||
def get_admin_context(cls, **kwargs):
|
||||
# TODO(kiall): Remove Me
|
||||
|
@ -69,7 +69,7 @@ class Handler(Plugin):
|
||||
"""
|
||||
Return the domain for this context
|
||||
"""
|
||||
context = DesignateContext.get_admin_context()
|
||||
context = DesignateContext.get_admin_context(all_tenants=True)
|
||||
return central_api.get_domain(context, domain_id)
|
||||
|
||||
|
||||
@ -99,7 +99,7 @@ class BaseAddressHandler(Handler):
|
||||
LOG.debug('Event data: %s' % data)
|
||||
data['domain'] = domain['name']
|
||||
|
||||
context = DesignateContext.get_admin_context()
|
||||
context = DesignateContext.get_admin_context(all_tenants=True)
|
||||
|
||||
for addr in addresses:
|
||||
record_data = data.copy()
|
||||
@ -128,7 +128,7 @@ class BaseAddressHandler(Handler):
|
||||
|
||||
:param criterion: Criterion to search and destroy records
|
||||
"""
|
||||
context = DesignateContext.get_admin_context()
|
||||
context = DesignateContext.get_admin_context(all_tenants=True)
|
||||
|
||||
if managed:
|
||||
criterion.update({
|
||||
|
133
designate/openstack/common/strutils.py
Normal file
133
designate/openstack/common/strutils.py
Normal file
@ -0,0 +1,133 @@
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright 2011 OpenStack Foundation.
|
||||
# All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
System-level utilities and helper functions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def int_from_bool_as_string(subject):
|
||||
"""
|
||||
Interpret a string as a boolean and return either 1 or 0.
|
||||
|
||||
Any string value in:
|
||||
|
||||
('True', 'true', 'On', 'on', '1')
|
||||
|
||||
is interpreted as a boolean True.
|
||||
|
||||
Useful for JSON-decoded stuff and config file parsing
|
||||
"""
|
||||
return bool_from_string(subject) and 1 or 0
|
||||
|
||||
|
||||
def bool_from_string(subject):
|
||||
"""
|
||||
Interpret a string as a boolean.
|
||||
|
||||
Any string value in:
|
||||
|
||||
('True', 'true', 'On', 'on', 'Yes', 'yes', '1')
|
||||
|
||||
is interpreted as a boolean True.
|
||||
|
||||
Useful for JSON-decoded stuff and config file parsing
|
||||
"""
|
||||
if isinstance(subject, bool):
|
||||
return subject
|
||||
if isinstance(subject, basestring):
|
||||
if subject.strip().lower() in ('true', 'on', 'yes', '1'):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def safe_decode(text, incoming=None, errors='strict'):
|
||||
"""
|
||||
Decodes incoming str using `incoming` if they're
|
||||
not already unicode.
|
||||
|
||||
:param incoming: Text's current encoding
|
||||
:param errors: Errors handling policy. See here for valid
|
||||
values http://docs.python.org/2/library/codecs.html
|
||||
:returns: text or a unicode `incoming` encoded
|
||||
representation of it.
|
||||
:raises TypeError: If text is not an isntance of basestring
|
||||
"""
|
||||
if not isinstance(text, basestring):
|
||||
raise TypeError("%s can't be decoded" % type(text))
|
||||
|
||||
if isinstance(text, unicode):
|
||||
return text
|
||||
|
||||
if not incoming:
|
||||
incoming = (sys.stdin.encoding or
|
||||
sys.getdefaultencoding())
|
||||
|
||||
try:
|
||||
return text.decode(incoming, errors)
|
||||
except UnicodeDecodeError:
|
||||
# Note(flaper87) If we get here, it means that
|
||||
# sys.stdin.encoding / sys.getdefaultencoding
|
||||
# didn't return a suitable encoding to decode
|
||||
# text. This happens mostly when global LANG
|
||||
# var is not set correctly and there's no
|
||||
# default encoding. In this case, most likely
|
||||
# python will use ASCII or ANSI encoders as
|
||||
# default encodings but they won't be capable
|
||||
# of decoding non-ASCII characters.
|
||||
#
|
||||
# Also, UTF-8 is being used since it's an ASCII
|
||||
# extension.
|
||||
return text.decode('utf-8', errors)
|
||||
|
||||
|
||||
def safe_encode(text, incoming=None,
|
||||
encoding='utf-8', errors='strict'):
|
||||
"""
|
||||
Encodes incoming str/unicode using `encoding`. If
|
||||
incoming is not specified, text is expected to
|
||||
be encoded with current python's default encoding.
|
||||
(`sys.getdefaultencoding`)
|
||||
|
||||
:param incoming: Text's current encoding
|
||||
:param encoding: Expected encoding for text (Default UTF-8)
|
||||
:param errors: Errors handling policy. See here for valid
|
||||
values http://docs.python.org/2/library/codecs.html
|
||||
:returns: text or a bytestring `encoding` encoded
|
||||
representation of it.
|
||||
:raises TypeError: If text is not an isntance of basestring
|
||||
"""
|
||||
if not isinstance(text, basestring):
|
||||
raise TypeError("%s can't be encoded" % type(text))
|
||||
|
||||
if not incoming:
|
||||
incoming = (sys.stdin.encoding or
|
||||
sys.getdefaultencoding())
|
||||
|
||||
if isinstance(text, unicode):
|
||||
return text.encode(encoding, errors)
|
||||
elif text and encoding != incoming:
|
||||
# Decode text before encoding it with `encoding`
|
||||
text = safe_decode(text, incoming, errors)
|
||||
return text.encode(encoding, errors)
|
||||
|
||||
return text
|
@ -40,6 +40,9 @@ class StorageQuota(Quota):
|
||||
return dict((q['resource'], q['hard_limit']) for q in quotas)
|
||||
|
||||
def get_quota(self, context, tenant_id, resource):
|
||||
context = context.deepcopy()
|
||||
context.all_tenants = True
|
||||
|
||||
quota = self.storage_api.find_quota(context, {
|
||||
'tenant_id': tenant_id,
|
||||
'resource': resource,
|
||||
@ -48,6 +51,9 @@ class StorageQuota(Quota):
|
||||
return {resource: quota['hard_limit']}
|
||||
|
||||
def set_quota(self, context, tenant_id, resource, hard_limit):
|
||||
context = context.deepcopy()
|
||||
context.all_tenants = True
|
||||
|
||||
def create_quota():
|
||||
values = {
|
||||
'tenant_id': tenant_id,
|
||||
@ -81,6 +87,9 @@ class StorageQuota(Quota):
|
||||
return {resource: hard_limit}
|
||||
|
||||
def reset_quotas(self, context, tenant_id):
|
||||
context = context.deepcopy()
|
||||
context.all_tenants = True
|
||||
|
||||
quotas = self.storage_api.find_quotas(context, {
|
||||
'tenant_id': tenant_id,
|
||||
})
|
||||
|
@ -65,6 +65,15 @@ class SQLAlchemyStorage(base.Storage):
|
||||
|
||||
return query
|
||||
|
||||
def _apply_tenant_criteria(self, context, model, query):
|
||||
if hasattr(model, 'tenant_id'):
|
||||
if context.all_tenants:
|
||||
LOG.debug('Including all tenants items in query results')
|
||||
else:
|
||||
query = query.filter(model.tenant_id == context.tenant_id)
|
||||
|
||||
return query
|
||||
|
||||
def _apply_deleted_criteria(self, context, model, query):
|
||||
if issubclass(model, SoftDeleteMixin):
|
||||
if context.show_deleted:
|
||||
@ -83,6 +92,7 @@ class SQLAlchemyStorage(base.Storage):
|
||||
# First up, create a query and apply the various filters
|
||||
query = self.session.query(model)
|
||||
query = self._apply_criterion(model, query, criterion)
|
||||
query = self._apply_tenant_criteria(context, model, query)
|
||||
query = self._apply_deleted_criteria(context, model, query)
|
||||
|
||||
if one:
|
||||
@ -255,6 +265,7 @@ class SQLAlchemyStorage(base.Storage):
|
||||
# returns an array of tenant_id & count of their domains
|
||||
query = self.session.query(models.Domain.tenant_id,
|
||||
func.count(models.Domain.id))
|
||||
query = self._apply_tenant_criteria(context, models.Domain, query)
|
||||
query = self._apply_deleted_criteria(context, models.Domain, query)
|
||||
query = query.group_by(models.Domain.tenant_id)
|
||||
|
||||
@ -263,8 +274,9 @@ class SQLAlchemyStorage(base.Storage):
|
||||
def get_tenant(self, context, tenant_id):
|
||||
# get list list & count of all domains owned by given tenant_id
|
||||
query = self.session.query(models.Domain.name)
|
||||
query = query.filter(models.Domain.tenant_id == tenant_id)
|
||||
query = self._apply_tenant_criteria(context, models.Domain, query)
|
||||
query = self._apply_deleted_criteria(context, models.Domain, query)
|
||||
query = query.filter(models.Domain.tenant_id == tenant_id)
|
||||
|
||||
result = query.all()
|
||||
|
||||
@ -278,6 +290,7 @@ class SQLAlchemyStorage(base.Storage):
|
||||
# tenants are the owner of domains, count the number of unique tenants
|
||||
# select count(distinct tenant_id) from domains
|
||||
query = self.session.query(distinct(models.Domain.tenant_id))
|
||||
query = self._apply_tenant_criteria(context, models.Domain, query)
|
||||
query = self._apply_deleted_criteria(context, models.Domain, query)
|
||||
|
||||
return query.count()
|
||||
@ -339,6 +352,7 @@ class SQLAlchemyStorage(base.Storage):
|
||||
def count_domains(self, context, criterion=None):
|
||||
query = self.session.query(models.Domain)
|
||||
query = self._apply_criterion(models.Domain, query, criterion)
|
||||
query = self._apply_tenant_criteria(context, models.Domain, query)
|
||||
query = self._apply_deleted_criteria(context, models.Domain, query)
|
||||
|
||||
return query.count()
|
||||
@ -410,6 +424,7 @@ class SQLAlchemyStorage(base.Storage):
|
||||
|
||||
def count_records(self, context, criterion=None):
|
||||
query = self.session.query(models.Record)
|
||||
query = self._apply_tenant_criteria(context, models.Record, query)
|
||||
query = self._apply_criterion(models.Record, query, criterion)
|
||||
return query.count()
|
||||
|
||||
|
@ -93,11 +93,9 @@ class PolicyFixture(fixtures.Fixture):
|
||||
|
||||
class TestCase(test.BaseTestCase):
|
||||
quota_fixtures = [{
|
||||
'tenant_id': '12345',
|
||||
'resource': 'domains',
|
||||
'hard_limit': 5,
|
||||
}, {
|
||||
'tenant_id': '12345',
|
||||
'resource': 'records',
|
||||
'hard_limit': 50,
|
||||
}]
|
||||
|
@ -131,38 +131,6 @@ class KeystoneContextMiddlewareTest(ApiTestCase):
|
||||
self.assertEqual('TenantID', context.tenant_id)
|
||||
self.assertEqual(['admin', 'Member'], context.roles)
|
||||
|
||||
def test_process_request_sudo(self):
|
||||
# Set the policy to accept the authz
|
||||
self.policy({'use_sudo': '@'})
|
||||
|
||||
app = middleware.KeystoneContextMiddleware({})
|
||||
|
||||
request = FakeRequest()
|
||||
|
||||
request.headers = {
|
||||
'X-Auth-Token': 'AuthToken',
|
||||
'X-User-ID': 'UserID',
|
||||
'X-Tenant-ID': 'TenantID',
|
||||
'X-Roles': 'admin,Member',
|
||||
'X-Designate-Sudo-Tenant-ID':
|
||||
'5a993bf8-d521-420a-81e1-192d9cc3d5a0'
|
||||
}
|
||||
|
||||
# Process the request
|
||||
app.process_request(request)
|
||||
|
||||
self.assertIn('context', request.environ)
|
||||
|
||||
context = request.environ['context']
|
||||
|
||||
self.assertFalse(context.is_admin)
|
||||
self.assertEqual('AuthToken', context.auth_token)
|
||||
self.assertEqual('UserID', context.user_id)
|
||||
self.assertEqual('TenantID', context.original_tenant_id)
|
||||
self.assertEqual('5a993bf8-d521-420a-81e1-192d9cc3d5a0',
|
||||
context.tenant_id)
|
||||
self.assertEqual(['admin', 'Member'], context.roles)
|
||||
|
||||
def test_process_request_invalid_keystone_token(self):
|
||||
app = middleware.KeystoneContextMiddleware({})
|
||||
|
||||
|
@ -37,9 +37,10 @@ class ApiV1Test(ApiTestCase):
|
||||
self.app.wsgi_app = middleware.FaultWrapperMiddleware(
|
||||
self.app.wsgi_app)
|
||||
|
||||
# Inject the NoAuth middleware
|
||||
self.app.wsgi_app = middleware.NoAuthContextMiddleware(
|
||||
self.app.wsgi_app)
|
||||
# Inject the TestAuth middleware
|
||||
self.app.wsgi_app = middleware.TestContextMiddleware(
|
||||
self.app.wsgi_app, self.admin_context.tenant_id,
|
||||
self.admin_context.user_id)
|
||||
|
||||
# Obtain a test client
|
||||
self.client = self.app.test_client()
|
||||
|
@ -25,6 +25,13 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApiV1ServersTest(ApiV1Test):
|
||||
def setUp(self):
|
||||
super(ApiV1ServersTest, self).setUp()
|
||||
|
||||
# All Server Checks should be performed as an admin, so..
|
||||
# Override to policy to make everyone an admin.
|
||||
self.policy({'admin': '@'})
|
||||
|
||||
def test_create_server(self):
|
||||
# Create a server
|
||||
fixture = self.get_server_fixture(0)
|
||||
|
@ -36,8 +36,10 @@ class ApiV2TestCase(ApiTestCase):
|
||||
# Inject the FaultWrapper middleware
|
||||
self.app = middleware.FaultWrapperMiddleware(self.app)
|
||||
|
||||
# Inject the NoAuth middleware
|
||||
self.app = middleware.NoAuthContextMiddleware(self.app)
|
||||
# Inject the TestContext middleware
|
||||
self.app = middleware.TestContextMiddleware(
|
||||
self.app, self.admin_context.tenant_id,
|
||||
self.admin_context.tenant_id)
|
||||
|
||||
# Obtain a test client
|
||||
self.client = TestApp(self.app)
|
||||
|
@ -295,18 +295,21 @@ class CentralServiceTest(CentralTestCase):
|
||||
|
||||
# Tenant Tests
|
||||
def test_count_tenants(self):
|
||||
context = self.get_admin_context()
|
||||
admin_context = self.get_admin_context()
|
||||
admin_context.all_tenants = True
|
||||
|
||||
tenant_one_context = self.get_context(tenant=1)
|
||||
tenant_two_context = self.get_context(tenant=2)
|
||||
|
||||
# in the beginning, there should be nothing
|
||||
tenants = self.central_service.count_tenants(self.admin_context)
|
||||
tenants = self.central_service.count_tenants(admin_context)
|
||||
self.assertEqual(tenants, 0)
|
||||
|
||||
# Explicitly set a tenant_id
|
||||
context.tenant_id = '1'
|
||||
self.create_domain(fixture=0, context=context)
|
||||
context.tenant_id = '2'
|
||||
self.create_domain(fixture=1, context=context)
|
||||
self.create_domain(fixture=0, context=tenant_one_context)
|
||||
self.create_domain(fixture=1, context=tenant_two_context)
|
||||
|
||||
tenants = self.central_service.count_tenants(self.admin_context)
|
||||
tenants = self.central_service.count_tenants(admin_context)
|
||||
self.assertEqual(tenants, 2)
|
||||
|
||||
def test_count_tenants_policy_check(self):
|
||||
@ -398,20 +401,16 @@ class CentralServiceTest(CentralTestCase):
|
||||
self.create_domain()
|
||||
|
||||
def test_create_subdomain(self):
|
||||
context = self.get_admin_context()
|
||||
|
||||
# Explicitly set a tenant_id
|
||||
context.tenant_id = '1'
|
||||
|
||||
# Create the Parent Domain using fixture 0
|
||||
parent_domain = self.create_domain(fixture=0, context=context)
|
||||
parent_domain = self.create_domain(fixture=0)
|
||||
|
||||
# Prepare values for the subdomain using fixture 1 as a base
|
||||
values = self.get_domain_fixture(1)
|
||||
values['name'] = 'www.%s' % parent_domain['name']
|
||||
|
||||
# Create the subdomain
|
||||
domain = self.central_service.create_domain(context, values=values)
|
||||
domain = self.central_service.create_domain(
|
||||
self.admin_context, values=values)
|
||||
|
||||
# Ensure all values have been set correctly
|
||||
self.assertIsNotNone(domain['id'])
|
||||
@ -446,18 +445,17 @@ class CentralServiceTest(CentralTestCase):
|
||||
# Set the policy to accept the authz
|
||||
self.policy({'use_blacklisted_domain': '@'})
|
||||
|
||||
# Create a server
|
||||
self.create_server()
|
||||
|
||||
context = self.get_admin_context()
|
||||
|
||||
values = dict(
|
||||
name='blacklisted.com.',
|
||||
email='info@blacklisted.com'
|
||||
)
|
||||
|
||||
# Create a server
|
||||
self.create_server()
|
||||
|
||||
# Create a domain
|
||||
domain = self.central_service.create_domain(context, values=values)
|
||||
domain = self.central_service.create_domain(
|
||||
self.admin_context, values=values)
|
||||
|
||||
# Ensure all values have been set correctly
|
||||
self.assertIsNotNone(domain['id'])
|
||||
@ -471,8 +469,6 @@ class CentralServiceTest(CentralTestCase):
|
||||
# Set the policy to reject the authz
|
||||
self.policy({'use_blacklisted_domain': '!'})
|
||||
|
||||
context = self.get_admin_context()
|
||||
|
||||
values = dict(
|
||||
name='blacklisted.com.',
|
||||
email='info@blacklisted.com'
|
||||
@ -480,7 +476,8 @@ class CentralServiceTest(CentralTestCase):
|
||||
|
||||
with testtools.ExpectedException(exceptions.InvalidDomainName):
|
||||
# Create a domain
|
||||
self.central_service.create_domain(context, values=values)
|
||||
self.central_service.create_domain(
|
||||
self.admin_context, values=values)
|
||||
|
||||
def _test_create_domain_fail(self, values, exception):
|
||||
self.config(accepted_tlds_file='tlds-alpha-by-domain.txt.sample',
|
||||
@ -494,10 +491,10 @@ class CentralServiceTest(CentralTestCase):
|
||||
self.central_service.effective_tld._load_accepted_tld_list()
|
||||
self.central_service.effective_tld._load_effective_tld_list()
|
||||
|
||||
context = self.get_admin_context()
|
||||
with testtools.ExpectedException(exception):
|
||||
# Create an invalid domain
|
||||
self.central_service.create_domain(context, values=values)
|
||||
self.central_service.create_domain(
|
||||
self.admin_context, values=values)
|
||||
|
||||
def test_create_domain_invalid_tld_fail(self):
|
||||
self.config(accepted_tlds_file='tlds-alpha-by-domain.txt.sample',
|
||||
@ -511,8 +508,6 @@ class CentralServiceTest(CentralTestCase):
|
||||
self.central_service.effective_tld._load_accepted_tld_list()
|
||||
self.central_service.effective_tld._load_effective_tld_list()
|
||||
|
||||
context = self.get_admin_context()
|
||||
|
||||
# Create a server
|
||||
self.create_server()
|
||||
|
||||
@ -522,7 +517,7 @@ class CentralServiceTest(CentralTestCase):
|
||||
)
|
||||
|
||||
# Create a valid domain
|
||||
self.central_service.create_domain(context, values=values)
|
||||
self.central_service.create_domain(self.admin_context, values=values)
|
||||
|
||||
values = dict(
|
||||
name='invalid.NeT1.',
|
||||
@ -531,7 +526,8 @@ class CentralServiceTest(CentralTestCase):
|
||||
|
||||
with testtools.ExpectedException(exceptions.InvalidTLD):
|
||||
# Create an invalid domain
|
||||
self.central_service.create_domain(context, values=values)
|
||||
self.central_service.create_domain(
|
||||
self.admin_context, values=values)
|
||||
|
||||
def test_create_domain_effective_tld_fail(self):
|
||||
values = dict(
|
||||
@ -539,8 +535,8 @@ class CentralServiceTest(CentralTestCase):
|
||||
email='info@invalid.com'
|
||||
)
|
||||
|
||||
self._test_create_domain_fail(values,
|
||||
exceptions.DomainIsSameAsAnEffectiveTLD)
|
||||
self._test_create_domain_fail(
|
||||
values, exceptions.DomainIsSameAsAnEffectiveTLD)
|
||||
|
||||
def test_idn_create_domain_effective_tld_fail(self):
|
||||
# Test creation of the effective TLD - brønnøysund.no
|
||||
@ -549,8 +545,8 @@ class CentralServiceTest(CentralTestCase):
|
||||
email='info@invalid.com'
|
||||
)
|
||||
|
||||
self._test_create_domain_fail(values,
|
||||
exceptions.DomainIsSameAsAnEffectiveTLD)
|
||||
self._test_create_domain_fail(
|
||||
values, exceptions.DomainIsSameAsAnEffectiveTLD)
|
||||
|
||||
def test_create_domain_re_effective_tld_fail(self):
|
||||
# co.uk is in the regular expression list for effective_tlds
|
||||
@ -559,21 +555,19 @@ class CentralServiceTest(CentralTestCase):
|
||||
email='info@invalid.com'
|
||||
)
|
||||
|
||||
self._test_create_domain_fail(values,
|
||||
exceptions.DomainIsSameAsAnEffectiveTLD)
|
||||
self._test_create_domain_fail(
|
||||
values, exceptions.DomainIsSameAsAnEffectiveTLD)
|
||||
|
||||
def test_find_domains(self):
|
||||
context = self.get_admin_context()
|
||||
|
||||
# Ensure we have no domains to start with.
|
||||
domains = self.central_service.find_domains(context)
|
||||
domains = self.central_service.find_domains(self.admin_context)
|
||||
self.assertEqual(len(domains), 0)
|
||||
|
||||
# Create a single domain (using default values)
|
||||
self.create_domain()
|
||||
|
||||
# Ensure we can retrieve the newly created domain
|
||||
domains = self.central_service.find_domains(context)
|
||||
domains = self.central_service.find_domains(self.admin_context)
|
||||
self.assertEqual(len(domains), 1)
|
||||
self.assertEqual(domains[0]['name'], 'example.com.')
|
||||
|
||||
@ -581,27 +575,30 @@ class CentralServiceTest(CentralTestCase):
|
||||
self.create_domain(name='example.net.')
|
||||
|
||||
# Ensure we can retrieve both domain
|
||||
domains = self.central_service.find_domains(context)
|
||||
domains = self.central_service.find_domains(self.admin_context)
|
||||
self.assertEqual(len(domains), 2)
|
||||
self.assertEqual(domains[0]['name'], 'example.com.')
|
||||
self.assertEqual(domains[1]['name'], 'example.net.')
|
||||
|
||||
def test_find_domains_criteria(self):
|
||||
context = self.get_admin_context()
|
||||
|
||||
# Create a domain
|
||||
domain_name = '%d.example.com.' % random.randint(10, 1000)
|
||||
expected_domain = self.create_domain(name=domain_name)
|
||||
|
||||
# Retrieve it, and ensure it's the same
|
||||
criterion = {'name': domain_name}
|
||||
domains = self.central_service.find_domains(context, criterion)
|
||||
|
||||
domains = self.central_service.find_domains(
|
||||
self.admin_context, criterion)
|
||||
|
||||
self.assertEqual(domains[0]['id'], expected_domain['id'])
|
||||
self.assertEqual(domains[0]['name'], expected_domain['name'])
|
||||
self.assertEqual(domains[0]['email'], expected_domain['email'])
|
||||
|
||||
def test_find_domains_tenant_restrictions(self):
|
||||
admin_context = self.get_admin_context()
|
||||
admin_context.all_tenants = True
|
||||
|
||||
tenant_one_context = self.get_context(tenant=1)
|
||||
tenant_two_context = self.get_context(tenant=2)
|
||||
|
||||
@ -610,65 +607,62 @@ class CentralServiceTest(CentralTestCase):
|
||||
self.assertEqual(len(domains), 0)
|
||||
|
||||
# Create a single domain (using default values)
|
||||
self.create_domain(context=tenant_one_context)
|
||||
domain = self.create_domain(context=tenant_one_context)
|
||||
|
||||
# Ensure admins can retrieve the newly created domain
|
||||
domains = self.central_service.find_domains(admin_context)
|
||||
self.assertEqual(len(domains), 1)
|
||||
self.assertEqual(domains[0]['name'], 'example.com.')
|
||||
self.assertEqual(domains[0]['name'], domain['name'])
|
||||
|
||||
# Ensure tenant=1 can retrieve the newly created domain
|
||||
domains = self.central_service.find_domains(tenant_one_context)
|
||||
self.assertEqual(len(domains), 1)
|
||||
self.assertEqual(domains[0]['name'], 'example.com.')
|
||||
self.assertEqual(domains[0]['name'], domain['name'])
|
||||
|
||||
# Ensure tenant=2 can NOT retrieve the newly created domain
|
||||
domains = self.central_service.find_domains(tenant_two_context)
|
||||
self.assertEqual(len(domains), 0)
|
||||
|
||||
def test_get_domain(self):
|
||||
context = self.get_admin_context()
|
||||
|
||||
# Create a domain
|
||||
domain_name = '%d.example.com.' % random.randint(10, 1000)
|
||||
expected_domain = self.create_domain(name=domain_name)
|
||||
|
||||
# Retrieve it, and ensure it's the same
|
||||
domain = self.central_service.get_domain(context,
|
||||
expected_domain['id'])
|
||||
domain = self.central_service.get_domain(
|
||||
self.admin_context, expected_domain['id'])
|
||||
|
||||
self.assertEqual(domain['id'], expected_domain['id'])
|
||||
self.assertEqual(domain['name'], expected_domain['name'])
|
||||
self.assertEqual(domain['email'], expected_domain['email'])
|
||||
|
||||
def test_get_domain_servers(self):
|
||||
context = self.get_admin_context()
|
||||
|
||||
# Create a domain
|
||||
domain = self.create_domain()
|
||||
|
||||
# Retrieve the servers list
|
||||
servers = self.central_service.get_domain_servers(context,
|
||||
domain['id'])
|
||||
servers = self.central_service.get_domain_servers(
|
||||
self.admin_context, domain['id'])
|
||||
|
||||
self.assertTrue(len(servers) > 0)
|
||||
|
||||
def test_find_domain(self):
|
||||
context = self.get_admin_context()
|
||||
|
||||
# Create a domain
|
||||
domain_name = '%d.example.com.' % random.randint(10, 1000)
|
||||
expected_domain = self.create_domain(name=domain_name)
|
||||
|
||||
# Retrieve it, and ensure it's the same
|
||||
criterion = {'name': domain_name}
|
||||
domain = self.central_service.find_domain(context, criterion)
|
||||
|
||||
domain = self.central_service.find_domain(
|
||||
self.admin_context, criterion)
|
||||
|
||||
self.assertEqual(domain['id'], expected_domain['id'])
|
||||
self.assertEqual(domain['name'], expected_domain['name'])
|
||||
self.assertEqual(domain['email'], expected_domain['email'])
|
||||
self.assertIn('status', domain)
|
||||
|
||||
def test_update_domain(self):
|
||||
context = self.get_admin_context()
|
||||
|
||||
# Create a domain
|
||||
expected_domain = self.create_domain()
|
||||
|
||||
@ -677,12 +671,13 @@ class CentralServiceTest(CentralTestCase):
|
||||
|
||||
# Update the domain
|
||||
values = dict(email='new@example.com')
|
||||
self.central_service.update_domain(context, expected_domain['id'],
|
||||
values=values)
|
||||
|
||||
self.central_service.update_domain(
|
||||
self.admin_context, expected_domain['id'], values=values)
|
||||
|
||||
# Fetch the domain again
|
||||
domain = self.central_service.get_domain(context,
|
||||
expected_domain['id'])
|
||||
domain = self.central_service.get_domain(
|
||||
self.admin_context, expected_domain['id'])
|
||||
|
||||
# Ensure the domain was updated correctly
|
||||
self.assertTrue(domain['serial'] > expected_domain['serial'])
|
||||
@ -706,40 +701,36 @@ class CentralServiceTest(CentralTestCase):
|
||||
self.assertEqual(payload['tenant_id'], domain['tenant_id'])
|
||||
|
||||
def test_update_domain_without_incrementing_serial(self):
|
||||
context = self.get_admin_context()
|
||||
|
||||
# Create a domain
|
||||
expected_domain = self.create_domain()
|
||||
|
||||
# Update the domain
|
||||
values = dict(email='new@example.com')
|
||||
self.central_service.update_domain(context, expected_domain['id'],
|
||||
values=values,
|
||||
increment_serial=False)
|
||||
|
||||
self.central_service.update_domain(
|
||||
self.admin_context, expected_domain['id'], values=values,
|
||||
increment_serial=False)
|
||||
|
||||
# Fetch the domain again
|
||||
domain = self.central_service.get_domain(context,
|
||||
expected_domain['id'])
|
||||
domain = self.central_service.get_domain(
|
||||
self.admin_context, expected_domain['id'])
|
||||
|
||||
# Ensure the domain was updated correctly
|
||||
self.assertEqual(domain['serial'], expected_domain['serial'])
|
||||
self.assertEqual(domain['email'], 'new@example.com')
|
||||
|
||||
def test_update_domain_name_fail(self):
|
||||
context = self.get_admin_context()
|
||||
|
||||
# Create a domain
|
||||
expected_domain = self.create_domain()
|
||||
|
||||
# Update the domain
|
||||
with testtools.ExpectedException(exceptions.BadRequest):
|
||||
values = dict(name='renamed-domain.com.')
|
||||
self.central_service.update_domain(context, expected_domain['id'],
|
||||
values=values)
|
||||
|
||||
self.central_service.update_domain(
|
||||
self.admin_context, expected_domain['id'], values=values)
|
||||
|
||||
def test_delete_domain(self):
|
||||
context = self.get_admin_context()
|
||||
|
||||
# Create a domain
|
||||
domain = self.create_domain()
|
||||
|
||||
@ -747,11 +738,11 @@ class CentralServiceTest(CentralTestCase):
|
||||
self.reset_notifications()
|
||||
|
||||
# Delete the domain
|
||||
self.central_service.delete_domain(context, domain['id'])
|
||||
self.central_service.delete_domain(self.admin_context, domain['id'])
|
||||
|
||||
# Fetch the domain again, ensuring an exception is raised
|
||||
with testtools.ExpectedException(exceptions.DomainNotFound):
|
||||
self.central_service.get_domain(context, domain['id'])
|
||||
self.central_service.get_domain(self.admin_context, domain['id'])
|
||||
|
||||
# Ensure we sent exactly 1 notification
|
||||
notifications = self.get_notifications()
|
||||
@ -771,8 +762,6 @@ class CentralServiceTest(CentralTestCase):
|
||||
self.assertEqual(payload['tenant_id'], domain['tenant_id'])
|
||||
|
||||
def test_delete_parent_domain(self):
|
||||
context = self.get_admin_context()
|
||||
|
||||
# Create the Parent Domain using fixture 0
|
||||
parent_domain = self.create_domain(fixture=0)
|
||||
|
||||
@ -781,7 +770,8 @@ class CentralServiceTest(CentralTestCase):
|
||||
|
||||
# Attempt to delete the parent domain
|
||||
with testtools.ExpectedException(exceptions.DomainHasSubdomain):
|
||||
self.central_service.delete_domain(context, parent_domain['id'])
|
||||
self.central_service.delete_domain(
|
||||
self.admin_context, parent_domain['id'])
|
||||
|
||||
def test_count_domains(self):
|
||||
# in the beginning, there should be nothing
|
||||
@ -805,24 +795,22 @@ class CentralServiceTest(CentralTestCase):
|
||||
self.central_service.count_domains(self.get_context())
|
||||
|
||||
def test_touch_domain(self):
|
||||
context = self.get_admin_context()
|
||||
|
||||
# Create a domain
|
||||
expected_domain = self.create_domain()
|
||||
|
||||
# Touch the domain
|
||||
self.central_service.touch_domain(context, expected_domain['id'])
|
||||
self.central_service.touch_domain(
|
||||
self.admin_context, expected_domain['id'])
|
||||
|
||||
# Fetch the domain again
|
||||
domain = self.central_service.get_domain(context,
|
||||
expected_domain['id'])
|
||||
domain = self.central_service.get_domain(
|
||||
self.admin_context, expected_domain['id'])
|
||||
|
||||
# Ensure the serial was incremented
|
||||
self.assertTrue(domain['serial'] > expected_domain['serial'])
|
||||
|
||||
# Record Tests
|
||||
def test_create_record(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
|
||||
values = dict(
|
||||
@ -832,8 +820,8 @@ class CentralServiceTest(CentralTestCase):
|
||||
)
|
||||
|
||||
# Create a record
|
||||
record = self.central_service.create_record(context, domain['id'],
|
||||
values=values)
|
||||
record = self.central_service.create_record(
|
||||
self.admin_context, domain['id'], values=values)
|
||||
|
||||
# Ensure all values have been set correctly
|
||||
self.assertIsNotNone(record['id'])
|
||||
@ -854,7 +842,6 @@ class CentralServiceTest(CentralTestCase):
|
||||
self.create_record(domain)
|
||||
|
||||
def test_create_record_without_incrementing_serial(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
|
||||
values = dict(
|
||||
@ -864,9 +851,9 @@ class CentralServiceTest(CentralTestCase):
|
||||
)
|
||||
|
||||
# Create a record
|
||||
record = self.central_service.create_record(context, domain['id'],
|
||||
values=values,
|
||||
increment_serial=False)
|
||||
record = self.central_service.create_record(
|
||||
self.admin_context, domain['id'], values=values,
|
||||
increment_serial=False)
|
||||
|
||||
# Ensure all values have been set correctly
|
||||
self.assertIsNotNone(record['id'])
|
||||
@ -876,11 +863,12 @@ class CentralServiceTest(CentralTestCase):
|
||||
self.assertEqual(record['data'], values['data'])
|
||||
|
||||
# Ensure the domains serial number was not updated
|
||||
updated_domain = self.central_service.get_domain(context, domain['id'])
|
||||
updated_domain = self.central_service.get_domain(
|
||||
self.admin_context, domain['id'])
|
||||
|
||||
self.assertEqual(domain['serial'], updated_domain['serial'])
|
||||
|
||||
def test_create_cname_record_at_apex(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
|
||||
values = dict(
|
||||
@ -891,11 +879,10 @@ class CentralServiceTest(CentralTestCase):
|
||||
|
||||
# Attempt to create a CNAME record at the apex
|
||||
with testtools.ExpectedException(exceptions.InvalidRecordLocation):
|
||||
self.central_service.create_record(context, domain['id'],
|
||||
values=values)
|
||||
self.central_service.create_record(
|
||||
self.admin_context, domain['id'], values=values)
|
||||
|
||||
def test_create_cname_record_above_an_a_record(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
|
||||
values = dict(
|
||||
@ -904,8 +891,8 @@ class CentralServiceTest(CentralTestCase):
|
||||
data='127.0.0.1'
|
||||
)
|
||||
|
||||
self.central_service.create_record(context, domain['id'],
|
||||
values=values)
|
||||
self.central_service.create_record(
|
||||
self.admin_context, domain['id'], values=values)
|
||||
|
||||
# Create a CNAME record alongside an A record
|
||||
values = dict(
|
||||
@ -914,13 +901,12 @@ class CentralServiceTest(CentralTestCase):
|
||||
data='example.org.'
|
||||
)
|
||||
|
||||
record = self.central_service.create_record(context, domain['id'],
|
||||
values=values)
|
||||
record = self.central_service.create_record(
|
||||
self.admin_context, domain['id'], values=values)
|
||||
|
||||
self.assertIn('id', record)
|
||||
|
||||
def test_create_cname_record_below_an_a_record(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
|
||||
values = dict(
|
||||
@ -929,8 +915,8 @@ class CentralServiceTest(CentralTestCase):
|
||||
data='127.0.0.1'
|
||||
)
|
||||
|
||||
self.central_service.create_record(context, domain['id'],
|
||||
values=values)
|
||||
self.central_service.create_record(
|
||||
self.admin_context, domain['id'], values=values)
|
||||
|
||||
# Create a CNAME record alongside an A record
|
||||
values = dict(
|
||||
@ -939,13 +925,12 @@ class CentralServiceTest(CentralTestCase):
|
||||
data='example.org.'
|
||||
)
|
||||
|
||||
record = self.central_service.create_record(context, domain['id'],
|
||||
values=values)
|
||||
record = self.central_service.create_record(
|
||||
self.admin_context, domain['id'], values=values)
|
||||
|
||||
self.assertIn('id', record)
|
||||
|
||||
def test_create_cname_record_alongside_an_a_record(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
|
||||
values = dict(
|
||||
@ -954,8 +939,8 @@ class CentralServiceTest(CentralTestCase):
|
||||
data='127.0.0.1'
|
||||
)
|
||||
|
||||
self.central_service.create_record(context, domain['id'],
|
||||
values=values)
|
||||
self.central_service.create_record(
|
||||
self.admin_context, domain['id'], values=values)
|
||||
|
||||
# Attempt to create a CNAME record alongside an A record
|
||||
with testtools.ExpectedException(exceptions.InvalidRecordLocation):
|
||||
@ -965,11 +950,10 @@ class CentralServiceTest(CentralTestCase):
|
||||
data='example.org.'
|
||||
)
|
||||
|
||||
self.central_service.create_record(context, domain['id'],
|
||||
values=values)
|
||||
self.central_service.create_record(
|
||||
self.admin_context, domain['id'], values=values)
|
||||
|
||||
def test_create_an_a_record_alongside_a_cname_record(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
|
||||
values = dict(
|
||||
@ -978,8 +962,8 @@ class CentralServiceTest(CentralTestCase):
|
||||
data='example.org.'
|
||||
)
|
||||
|
||||
self.central_service.create_record(context, domain['id'],
|
||||
values=values)
|
||||
self.central_service.create_record(
|
||||
self.admin_context, domain['id'], values=values)
|
||||
|
||||
# Attempt to create a CNAME record alongside an A record
|
||||
with testtools.ExpectedException(exceptions.InvalidRecordLocation):
|
||||
@ -989,11 +973,10 @@ class CentralServiceTest(CentralTestCase):
|
||||
data='127.0.0.1'
|
||||
)
|
||||
|
||||
self.central_service.create_record(context, domain['id'],
|
||||
values=values)
|
||||
self.central_service.create_record(
|
||||
self.admin_context, domain['id'], values=values)
|
||||
|
||||
def test_create_duplicate_ptr_record(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain(values={'name': '2.0.192.in-addr.arpa.'})
|
||||
|
||||
values = dict(
|
||||
@ -1002,8 +985,8 @@ class CentralServiceTest(CentralTestCase):
|
||||
data='www.example.org.'
|
||||
)
|
||||
|
||||
self.central_service.create_record(context, domain['id'],
|
||||
values=values)
|
||||
self.central_service.create_record(
|
||||
self.admin_context, domain['id'], values=values)
|
||||
|
||||
# Attempt to create a second PTR with the same name.
|
||||
with testtools.ExpectedException(exceptions.DuplicateRecord):
|
||||
@ -1013,22 +996,25 @@ class CentralServiceTest(CentralTestCase):
|
||||
data='www.example.com.'
|
||||
)
|
||||
|
||||
self.central_service.create_record(context, domain['id'],
|
||||
values=values)
|
||||
self.central_service.create_record(
|
||||
self.admin_context, domain['id'], values=values)
|
||||
|
||||
def test_find_records(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
|
||||
# Ensure we have no records to start with.
|
||||
records = self.central_service.find_records(context, domain['id'])
|
||||
records = self.central_service.find_records(
|
||||
self.admin_context, domain['id'])
|
||||
|
||||
self.assertEqual(len(records), 0)
|
||||
|
||||
# Create a single record (using default values)
|
||||
self.create_record(domain)
|
||||
|
||||
# Ensure we can retrieve the newly created record
|
||||
records = self.central_service.find_records(context, domain['id'])
|
||||
records = self.central_service.find_records(
|
||||
self.admin_context, domain['id'])
|
||||
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertEqual(records[0]['name'], 'www.%s' % domain['name'])
|
||||
|
||||
@ -1036,13 +1022,14 @@ class CentralServiceTest(CentralTestCase):
|
||||
self.create_record(domain, name='mail.%s' % domain['name'])
|
||||
|
||||
# Ensure we can retrieve both records
|
||||
records = self.central_service.find_records(context, domain['id'])
|
||||
records = self.central_service.find_records(
|
||||
self.admin_context, domain['id'])
|
||||
|
||||
self.assertEqual(len(records), 2)
|
||||
self.assertEqual(records[0]['name'], 'www.%s' % domain['name'])
|
||||
self.assertEqual(records[1]['name'], 'mail.%s' % domain['name'])
|
||||
|
||||
def test_get_record(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
|
||||
# Create a record
|
||||
@ -1050,14 +1037,14 @@ class CentralServiceTest(CentralTestCase):
|
||||
expected_record = self.create_record(domain, name=record_name)
|
||||
|
||||
# Retrieve it, and ensure it's the same
|
||||
record = self.central_service.get_record(context, domain['id'],
|
||||
expected_record['id'])
|
||||
record = self.central_service.get_record(
|
||||
self.admin_context, domain['id'], expected_record['id'])
|
||||
|
||||
self.assertEqual(record['id'], expected_record['id'])
|
||||
self.assertEqual(record['name'], expected_record['name'])
|
||||
self.assertIn('status', record)
|
||||
|
||||
def test_find_record(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
|
||||
# Create a record
|
||||
@ -1066,14 +1053,15 @@ class CentralServiceTest(CentralTestCase):
|
||||
|
||||
# Retrieve it, and ensure it's the same
|
||||
criterion = {'name': record_name}
|
||||
record = self.central_service.find_record(context, domain['id'],
|
||||
criterion)
|
||||
|
||||
record = self.central_service.find_record(
|
||||
self.admin_context, domain['id'], criterion)
|
||||
|
||||
self.assertEqual(record['id'], expected_record['id'])
|
||||
self.assertEqual(record['name'], expected_record['name'])
|
||||
self.assertIn('status', record)
|
||||
|
||||
def test_get_record_incorrect_domain_id(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
other_domain = self.create_domain(fixture=1)
|
||||
|
||||
@ -1083,11 +1071,10 @@ class CentralServiceTest(CentralTestCase):
|
||||
|
||||
# Ensure we get a 404 if we use the incorrect domain_id
|
||||
with testtools.ExpectedException(exceptions.RecordNotFound):
|
||||
self.central_service.get_record(context, other_domain['id'],
|
||||
expected_record['id'])
|
||||
self.central_service.get_record(
|
||||
self.admin_context, other_domain['id'], expected_record['id'])
|
||||
|
||||
def test_update_record(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
|
||||
# Create a record
|
||||
@ -1095,48 +1082,48 @@ class CentralServiceTest(CentralTestCase):
|
||||
|
||||
# Update the record
|
||||
values = dict(data='127.0.0.2')
|
||||
self.central_service.update_record(context, domain['id'],
|
||||
expected_record['id'],
|
||||
values=values)
|
||||
|
||||
self.central_service.update_record(
|
||||
self.admin_context, domain['id'], expected_record['id'],
|
||||
values=values)
|
||||
|
||||
# Fetch the record again
|
||||
record = self.central_service.get_record(context, domain['id'],
|
||||
expected_record['id'])
|
||||
record = self.central_service.get_record(
|
||||
self.admin_context, domain['id'], expected_record['id'])
|
||||
|
||||
# Ensure the record was updated correctly
|
||||
self.assertEqual(record['data'], '127.0.0.2')
|
||||
|
||||
def test_update_record_without_incrementing_serial(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
|
||||
# Create a record
|
||||
expected_record = self.create_record(domain)
|
||||
|
||||
# Fetch the domain so we have the latest serial number
|
||||
domain_before = self.central_service.get_domain(context, domain['id'])
|
||||
domain_before = self.central_service.get_domain(
|
||||
self.admin_context, domain['id'])
|
||||
|
||||
# Update the record
|
||||
values = dict(data='127.0.0.2')
|
||||
self.central_service.update_record(context,
|
||||
domain['id'],
|
||||
expected_record['id'],
|
||||
values,
|
||||
increment_serial=False)
|
||||
self.central_service.update_record(
|
||||
self.admin_context, domain['id'], expected_record['id'],
|
||||
values, increment_serial=False)
|
||||
|
||||
# Fetch the record again
|
||||
record = self.central_service.get_record(context, domain['id'],
|
||||
expected_record['id'])
|
||||
record = self.central_service.get_record(
|
||||
self.admin_context, domain['id'], expected_record['id'])
|
||||
|
||||
# Ensure the record was updated correctly
|
||||
self.assertEqual(record['data'], '127.0.0.2')
|
||||
|
||||
# Ensure the domains serial number was not updated
|
||||
domain_after = self.central_service.get_domain(context, domain['id'])
|
||||
domain_after = self.central_service.get_domain(
|
||||
self.admin_context, domain['id'])
|
||||
|
||||
self.assertEqual(domain_before['serial'], domain_after['serial'])
|
||||
|
||||
def test_update_record_incorrect_domain_id(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
other_domain = self.create_domain(fixture=1)
|
||||
|
||||
@ -1148,12 +1135,11 @@ class CentralServiceTest(CentralTestCase):
|
||||
|
||||
# Ensure we get a 404 if we use the incorrect domain_id
|
||||
with testtools.ExpectedException(exceptions.RecordNotFound):
|
||||
self.central_service.update_record(context, other_domain['id'],
|
||||
expected_record['id'],
|
||||
values=values)
|
||||
self.central_service.update_record(
|
||||
self.admin_context, other_domain['id'], expected_record['id'],
|
||||
values=values)
|
||||
|
||||
def test_update_record_duplicate_ptr(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain(values={'name': '2.0.192.in-addr.arpa.'})
|
||||
|
||||
values = dict(
|
||||
@ -1162,8 +1148,8 @@ class CentralServiceTest(CentralTestCase):
|
||||
data='www.example.org.'
|
||||
)
|
||||
|
||||
self.central_service.create_record(context, domain['id'],
|
||||
values=values)
|
||||
self.central_service.create_record(
|
||||
self.admin_context, domain['id'], values=values)
|
||||
|
||||
values = dict(
|
||||
name='2.%s' % domain['name'],
|
||||
@ -1171,8 +1157,8 @@ class CentralServiceTest(CentralTestCase):
|
||||
data='www.example.org.'
|
||||
)
|
||||
|
||||
record = self.central_service.create_record(context, domain['id'],
|
||||
values=values)
|
||||
record = self.central_service.create_record(
|
||||
self.admin_context, domain['id'], values=values)
|
||||
|
||||
# Attempt to create a second PTR with the same name.
|
||||
with testtools.ExpectedException(exceptions.DuplicateRecord):
|
||||
@ -1180,12 +1166,10 @@ class CentralServiceTest(CentralTestCase):
|
||||
name='1.%s' % domain['name']
|
||||
)
|
||||
|
||||
self.central_service.update_record(context, domain['id'],
|
||||
record['id'],
|
||||
values=values)
|
||||
self.central_service.update_record(
|
||||
self.admin_context, domain['id'], record['id'], values=values)
|
||||
|
||||
def test_update_record_cname_data(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
|
||||
# Create a record
|
||||
@ -1194,19 +1178,18 @@ class CentralServiceTest(CentralTestCase):
|
||||
|
||||
# Update the record
|
||||
values = dict(data='example.com.')
|
||||
self.central_service.update_record(context, domain['id'],
|
||||
self.central_service.update_record(self.admin_context, domain['id'],
|
||||
expected_record['id'],
|
||||
values=values)
|
||||
|
||||
# Fetch the record again
|
||||
record = self.central_service.get_record(context, domain['id'],
|
||||
expected_record['id'])
|
||||
record = self.central_service.get_record(
|
||||
self.admin_context, domain['id'], expected_record['id'])
|
||||
|
||||
# Ensure the record was updated correctly
|
||||
self.assertEqual(record['data'], 'example.com.')
|
||||
|
||||
def test_update_record_ptr_data(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain(name='2.0.192.in-addr.arpa.')
|
||||
|
||||
# Create a record
|
||||
@ -1218,57 +1201,60 @@ class CentralServiceTest(CentralTestCase):
|
||||
|
||||
# Update the record
|
||||
values = dict(data='example.com.')
|
||||
self.central_service.update_record(context, domain['id'],
|
||||
self.central_service.update_record(self.admin_context, domain['id'],
|
||||
expected_record['id'],
|
||||
values=values)
|
||||
|
||||
# Fetch the record again
|
||||
record = self.central_service.get_record(context, domain['id'],
|
||||
record = self.central_service.get_record(self.admin_context,
|
||||
domain['id'],
|
||||
expected_record['id'])
|
||||
|
||||
# Ensure the record was updated correctly
|
||||
self.assertEqual(record['data'], 'example.com.')
|
||||
|
||||
def test_delete_record(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
|
||||
# Create a record
|
||||
record = self.create_record(domain)
|
||||
|
||||
# Delete the record
|
||||
self.central_service.delete_record(context, domain['id'], record['id'])
|
||||
self.central_service.delete_record(self.admin_context, domain['id'],
|
||||
record['id'])
|
||||
|
||||
# Fetch the record again, ensuring an exception is raised
|
||||
with testtools.ExpectedException(exceptions.RecordNotFound):
|
||||
self.central_service.get_record(context, domain['id'],
|
||||
self.central_service.get_record(self.admin_context, domain['id'],
|
||||
record['id'])
|
||||
|
||||
def test_delete_record_without_incrementing_serial(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
|
||||
# Create a record
|
||||
record = self.create_record(domain)
|
||||
|
||||
# Fetch the domain so we have the latest serial number
|
||||
domain_before = self.central_service.get_domain(context, domain['id'])
|
||||
domain_before = self.central_service.get_domain(
|
||||
self.admin_context, domain['id'])
|
||||
|
||||
# Delete the record
|
||||
self.central_service.delete_record(context, domain['id'], record['id'],
|
||||
increment_serial=False)
|
||||
self.central_service.delete_record(
|
||||
self.admin_context, domain['id'], record['id'],
|
||||
increment_serial=False)
|
||||
|
||||
# Fetch the record again, ensuring an exception is raised
|
||||
with testtools.ExpectedException(exceptions.RecordNotFound):
|
||||
self.central_service.get_record(context, domain['id'],
|
||||
record['id'])
|
||||
self.central_service.get_record(
|
||||
self.admin_context, domain['id'], record['id'])
|
||||
|
||||
# Ensure the domains serial number was not updated
|
||||
domain_after = self.central_service.get_domain(context, domain['id'])
|
||||
domain_after = self.central_service.get_domain(
|
||||
self.admin_context, domain['id'])
|
||||
|
||||
self.assertEqual(domain_before['serial'], domain_after['serial'])
|
||||
|
||||
def test_delete_record_incorrect_domain_id(self):
|
||||
context = self.get_admin_context()
|
||||
domain = self.create_domain()
|
||||
other_domain = self.create_domain(fixture=1)
|
||||
|
||||
@ -1277,8 +1263,8 @@ class CentralServiceTest(CentralTestCase):
|
||||
|
||||
# Ensure we get a 404 if we use the incorrect domain_id
|
||||
with testtools.ExpectedException(exceptions.RecordNotFound):
|
||||
self.central_service.delete_record(context, other_domain['id'],
|
||||
record['id'])
|
||||
self.central_service.delete_record(
|
||||
self.admin_context, other_domain['id'], record['id'])
|
||||
|
||||
def test_count_records(self):
|
||||
# in the beginning, there should be nothing
|
||||
|
@ -18,26 +18,6 @@ from designate import context
|
||||
|
||||
|
||||
class TestDesignateContext(TestCase):
|
||||
def test_sudo(self):
|
||||
# Set the policy to accept the authz
|
||||
self.policy({'use_sudo': '@'})
|
||||
|
||||
ctxt = context.DesignateContext(tenant='original')
|
||||
ctxt.sudo('effective')
|
||||
|
||||
self.assertEqual('effective', ctxt.tenant_id)
|
||||
self.assertEqual('original', ctxt.original_tenant_id)
|
||||
|
||||
def test_sudo_fail(self):
|
||||
# Set the policy to deny the authz
|
||||
self.policy({'use_sudo': '!'})
|
||||
|
||||
ctxt = context.DesignateContext(tenant='original')
|
||||
ctxt.sudo('effective')
|
||||
|
||||
self.assertEqual('original', ctxt.tenant_id)
|
||||
self.assertEqual('original', ctxt.original_tenant_id)
|
||||
|
||||
def test_deepcopy(self):
|
||||
orig = context.DesignateContext(user='12345', tenant='54321')
|
||||
copy = orig.deepcopy()
|
||||
|
@ -27,8 +27,10 @@ class StorageQuotaTest(tests.TestCase):
|
||||
self.quota = quota.get_quota()
|
||||
|
||||
def test_set_quota_create(self):
|
||||
quota = self.quota.set_quota(self.admin_context, 'tenant_id',
|
||||
'domains', 1500)
|
||||
context = self.get_admin_context()
|
||||
context.all_tenants = True
|
||||
|
||||
quota = self.quota.set_quota(context, 'tenant_id', 'domains', 1500)
|
||||
|
||||
self.assertEqual(quota, {'domains': 1500})
|
||||
|
||||
@ -39,19 +41,21 @@ class StorageQuotaTest(tests.TestCase):
|
||||
'resource': 'domains'
|
||||
}
|
||||
|
||||
quota = self.quota.storage_api.find_quota(self.admin_context,
|
||||
criterion)
|
||||
quota = self.quota.storage_api.find_quota(context, criterion)
|
||||
|
||||
self.assertEqual(quota['tenant_id'], 'tenant_id')
|
||||
self.assertEqual(quota['resource'], 'domains')
|
||||
self.assertEqual(quota['hard_limit'], 1500)
|
||||
|
||||
def test_set_quota_update(self):
|
||||
context = self.get_admin_context()
|
||||
context.all_tenants = True
|
||||
|
||||
# First up, Create the quota
|
||||
self.quota.set_quota(self.admin_context, 'tenant_id', 'domains', 1500)
|
||||
self.quota.set_quota(context, 'tenant_id', 'domains', 1500)
|
||||
|
||||
# Next, update the quota
|
||||
self.quota.set_quota(self.admin_context, 'tenant_id', 'domains', 1234)
|
||||
self.quota.set_quota(context, 'tenant_id', 'domains', 1234)
|
||||
|
||||
# Drop into the storage layer directly to ensure the quota was updated
|
||||
# sucessfully
|
||||
@ -60,23 +64,24 @@ class StorageQuotaTest(tests.TestCase):
|
||||
'resource': 'domains'
|
||||
}
|
||||
|
||||
quota = self.quota.storage_api.find_quota(self.admin_context,
|
||||
criterion)
|
||||
quota = self.quota.storage_api.find_quota(context, criterion)
|
||||
|
||||
self.assertEqual(quota['tenant_id'], 'tenant_id')
|
||||
self.assertEqual(quota['resource'], 'domains')
|
||||
self.assertEqual(quota['hard_limit'], 1234)
|
||||
|
||||
def test_reset_quotas(self):
|
||||
context = self.get_admin_context()
|
||||
context.all_tenants = True
|
||||
|
||||
# First up, Create a domains quota
|
||||
self.quota.set_quota(self.admin_context, 'tenant_id', 'domains', 1500)
|
||||
self.quota.set_quota(context, 'tenant_id', 'domains', 1500)
|
||||
|
||||
# Then, Create a domain_records quota
|
||||
self.quota.set_quota(self.admin_context, 'tenant_id', 'domain_records',
|
||||
800)
|
||||
self.quota.set_quota(context, 'tenant_id', 'domain_records', 800)
|
||||
|
||||
# Now, Reset the tenants quota
|
||||
self.quota.reset_quotas(self.admin_context, 'tenant_id')
|
||||
self.quota.reset_quotas(context, 'tenant_id')
|
||||
|
||||
# Drop into the storage layer directly to ensure the tenant has no
|
||||
# specific quotas registed.
|
||||
@ -84,7 +89,5 @@ class StorageQuotaTest(tests.TestCase):
|
||||
'tenant_id': 'tenant_id'
|
||||
}
|
||||
|
||||
quotas = self.quota.storage_api.find_quotas(self.admin_context,
|
||||
criterion)
|
||||
|
||||
quotas = self.quota.storage_api.find_quotas(context, criterion)
|
||||
self.assertEqual(0, len(quotas))
|
||||
|
@ -22,37 +22,55 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StorageTestCase(object):
|
||||
def create_quota(self, fixture=0, values={}):
|
||||
def create_quota(self, fixture=0, values={}, context=None):
|
||||
if not context:
|
||||
context = self.admin_context
|
||||
|
||||
fixture = self.get_quota_fixture(fixture, values)
|
||||
return fixture, self.storage.create_quota(self.admin_context, fixture)
|
||||
|
||||
def create_server(self, fixture=0, values={}):
|
||||
if 'tenant_id' not in fixture:
|
||||
fixture['tenant_id'] = context.tenant_id
|
||||
|
||||
return fixture, self.storage.create_quota(context, fixture)
|
||||
|
||||
def create_server(self, fixture=0, values={}, context=None):
|
||||
if not context:
|
||||
context = self.admin_context
|
||||
|
||||
fixture = self.get_server_fixture(fixture, values)
|
||||
return fixture, self.storage.create_server(self.admin_context, fixture)
|
||||
return fixture, self.storage.create_server(context, fixture)
|
||||
|
||||
def create_tsigkey(self, fixture=0, values={}, context=None):
|
||||
if not context:
|
||||
context = self.admin_context
|
||||
|
||||
def create_tsigkey(self, fixture=0, values={}):
|
||||
fixture = self.get_tsigkey_fixture(fixture, values)
|
||||
return fixture, self.storage.create_tsigkey(self.admin_context,
|
||||
fixture)
|
||||
return fixture, self.storage.create_tsigkey(context, fixture)
|
||||
|
||||
def create_domain(self, fixture=0, values={}, context=None):
|
||||
if not context:
|
||||
context = self.admin_context
|
||||
|
||||
def create_domain(self, fixture=0, values={}):
|
||||
fixture = self.get_domain_fixture(fixture, values)
|
||||
|
||||
if 'tenant_id' not in values:
|
||||
fixture['tenant_id'] = self.admin_context.tenant_id
|
||||
if 'tenant_id' not in fixture:
|
||||
fixture['tenant_id'] = context.tenant_id
|
||||
|
||||
return fixture, self.storage.create_domain(self.admin_context,
|
||||
fixture)
|
||||
return fixture, self.storage.create_domain(context, fixture)
|
||||
|
||||
def create_record(self, domain, fixture=0, values={}, context=None):
|
||||
if not context:
|
||||
context = self.admin_context
|
||||
|
||||
def create_record(self, domain, fixture=0, values={}):
|
||||
fixture = self.get_record_fixture(domain['name'], fixture, values)
|
||||
return fixture, self.storage.create_record(self.admin_context,
|
||||
return fixture, self.storage.create_record(context,
|
||||
domain['id'],
|
||||
fixture)
|
||||
|
||||
# Quota Tests
|
||||
def test_create_quota(self):
|
||||
values = self.get_quota_fixture()
|
||||
values['tenant_id'] = self.admin_context.tenant_id
|
||||
|
||||
result = self.storage.create_quota(self.admin_context, values=values)
|
||||
|
||||
@ -60,7 +78,7 @@ class StorageTestCase(object):
|
||||
self.assertIsNotNone(result['created_at'])
|
||||
self.assertIsNone(result['updated_at'])
|
||||
|
||||
self.assertEqual(result['tenant_id'], values['tenant_id'])
|
||||
self.assertEqual(result['tenant_id'], self.admin_context.tenant_id)
|
||||
self.assertEqual(result['resource'], values['resource'])
|
||||
self.assertEqual(result['hard_limit'], values['hard_limit'])
|
||||
|
||||
@ -187,15 +205,16 @@ class StorageTestCase(object):
|
||||
self.assertEqual(updated['hard_limit'], fixture['hard_limit'])
|
||||
|
||||
def test_update_quota_duplicate(self):
|
||||
# Create two quotas
|
||||
self.create_quota(fixture=0)
|
||||
_, quota = self.create_quota(fixture=1)
|
||||
context = self.get_admin_context()
|
||||
context.all_tenants = True
|
||||
|
||||
values = self.quota_fixtures[0]
|
||||
# Create two quotas
|
||||
self.create_quota(fixture=0, values={'tenant_id': '1'})
|
||||
_, quota = self.create_quota(fixture=0, values={'tenant_id': '2'})
|
||||
|
||||
with testtools.ExpectedException(exceptions.DuplicateQuota):
|
||||
self.storage.update_quota(self.admin_context, quota['id'],
|
||||
values)
|
||||
self.storage.update_quota(context, quota['id'],
|
||||
values={'tenant_id': '1'})
|
||||
|
||||
def test_update_quota_missing(self):
|
||||
with testtools.ExpectedException(exceptions.QuotaNotFound):
|
||||
@ -460,16 +479,19 @@ class StorageTestCase(object):
|
||||
|
||||
# Tenant Tests
|
||||
def test_find_tenants(self):
|
||||
context = self.get_admin_context()
|
||||
context.all_tenants = True
|
||||
|
||||
# create 3 domains in 2 tenants
|
||||
self.create_domain(fixture=0, values={'tenant_id': 'One'})
|
||||
_, domain = self.create_domain(fixture=1, values={'tenant_id': 'One'})
|
||||
self.create_domain(fixture=2, values={'tenant_id': 'Two'})
|
||||
|
||||
# Delete one of the domains.
|
||||
self.storage.delete_domain(self.admin_context, domain['id'])
|
||||
self.storage.delete_domain(context, domain['id'])
|
||||
|
||||
# Ensure we get accurate results
|
||||
result = self.storage.find_tenants(self.admin_context)
|
||||
result = self.storage.find_tenants(context)
|
||||
|
||||
expected = [{
|
||||
'id': 'One',
|
||||
@ -482,15 +504,18 @@ class StorageTestCase(object):
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_get_tenant(self):
|
||||
context = self.get_admin_context()
|
||||
context.all_tenants = True
|
||||
|
||||
# create 2 domains in a tenant
|
||||
_, domain_1 = self.create_domain(fixture=0, values={'tenant_id': 1})
|
||||
_, domain_2 = self.create_domain(fixture=1, values={'tenant_id': 1})
|
||||
_, domain_3 = self.create_domain(fixture=2, values={'tenant_id': 1})
|
||||
|
||||
# Delete one of the domains.
|
||||
self.storage.delete_domain(self.admin_context, domain_3['id'])
|
||||
self.storage.delete_domain(context, domain_3['id'])
|
||||
|
||||
result = self.storage.get_tenant(self.admin_context, 1)
|
||||
result = self.storage.get_tenant(context, 1)
|
||||
|
||||
self.assertEqual(result['id'], 1)
|
||||
self.assertEqual(result['domain_count'], 2)
|
||||
@ -498,8 +523,11 @@ class StorageTestCase(object):
|
||||
[domain_1['name'], domain_2['name']])
|
||||
|
||||
def test_count_tenants(self):
|
||||
context = self.get_admin_context()
|
||||
context.all_tenants = True
|
||||
|
||||
# in the beginning, there should be nothing
|
||||
tenants = self.storage.count_tenants(self.admin_context)
|
||||
tenants = self.storage.count_tenants(context)
|
||||
self.assertEqual(tenants, 0)
|
||||
|
||||
# create 2 domains with 2 tenants
|
||||
@ -508,9 +536,9 @@ class StorageTestCase(object):
|
||||
_, domain = self.create_domain(fixture=2, values={'tenant_id': 2})
|
||||
|
||||
# Delete one of the domains.
|
||||
self.storage.delete_domain(self.admin_context, domain['id'])
|
||||
self.storage.delete_domain(context, domain['id'])
|
||||
|
||||
tenants = self.storage.count_tenants(self.admin_context)
|
||||
tenants = self.storage.count_tenants(context)
|
||||
self.assertEqual(tenants, 2)
|
||||
|
||||
# Domain Tests
|
||||
@ -586,6 +614,38 @@ class StorageTestCase(object):
|
||||
self.assertEqual(results[0]['email'], domain_two['email'])
|
||||
self.assertIn('status', domain_two)
|
||||
|
||||
def test_find_domains_all_tenants(self):
|
||||
# Create two contexts with different tenant_id's
|
||||
one_context = self.get_admin_context()
|
||||
one_context.tenant = 1
|
||||
two_context = self.get_admin_context()
|
||||
two_context.tenant = 2
|
||||
|
||||
# Create normal and all_tenants context objects
|
||||
nm_context = self.get_admin_context()
|
||||
at_context = self.get_admin_context()
|
||||
at_context.all_tenants = True
|
||||
|
||||
# Create two domains in different tenants
|
||||
self.create_domain(fixture=0, context=one_context)
|
||||
self.create_domain(fixture=1, context=two_context)
|
||||
|
||||
# Ensure the all_tenants context see's two domains
|
||||
results = self.storage.find_domains(at_context)
|
||||
self.assertEqual(len(results), 2)
|
||||
|
||||
# Ensure the normal context see's no domains
|
||||
results = self.storage.find_domains(nm_context)
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
# Ensure the tenant 1 context see's 1 domain
|
||||
results = self.storage.find_domains(one_context)
|
||||
self.assertEqual(len(results), 1)
|
||||
|
||||
# Ensure the tenant 2 context see's 1 domain
|
||||
results = self.storage.find_domains(two_context)
|
||||
self.assertEqual(len(results), 1)
|
||||
|
||||
def test_get_domain(self):
|
||||
# Create a domain
|
||||
fixture, expected = self.create_domain()
|
||||
@ -604,7 +664,7 @@ class StorageTestCase(object):
|
||||
context = self.get_admin_context()
|
||||
context.show_deleted = True
|
||||
|
||||
_, domain = self.create_domain()
|
||||
_, domain = self.create_domain(context=context)
|
||||
|
||||
self.storage.delete_domain(context, domain['id'])
|
||||
self.storage.get_domain(context, domain['id'])
|
||||
@ -801,6 +861,41 @@ class StorageTestCase(object):
|
||||
|
||||
self.assertEqual(len(results), 1)
|
||||
|
||||
def test_find_records_all_tenants(self):
|
||||
# Create two contexts with different tenant_id's
|
||||
one_context = self.get_admin_context()
|
||||
one_context.tenant = 1
|
||||
two_context = self.get_admin_context()
|
||||
two_context.tenant = 2
|
||||
|
||||
# Create normal and all_tenants context objects
|
||||
nm_context = self.get_admin_context()
|
||||
at_context = self.get_admin_context()
|
||||
at_context.all_tenants = True
|
||||
|
||||
# Create two domains in different tenants, and 1 record in each
|
||||
_, domain_one = self.create_domain(fixture=0, context=one_context)
|
||||
self.create_record(domain_one, fixture=0, context=one_context)
|
||||
|
||||
_, domain_two = self.create_domain(fixture=1, context=two_context)
|
||||
self.create_record(domain_two, fixture=0, context=two_context)
|
||||
|
||||
# Ensure the all_tenants context see's two records
|
||||
results = self.storage.find_records(at_context)
|
||||
self.assertEqual(len(results), 2)
|
||||
|
||||
# Ensure the normal context see's no records
|
||||
results = self.storage.find_records(nm_context)
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
# Ensure the tenant 1 context see's 1 record
|
||||
results = self.storage.find_records(one_context)
|
||||
self.assertEqual(len(results), 1)
|
||||
|
||||
# Ensure the tenant 2 context see's 1 record
|
||||
results = self.storage.find_records(two_context)
|
||||
self.assertEqual(len(results), 1)
|
||||
|
||||
def test_get_record(self):
|
||||
_, domain = self.create_domain()
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
{
|
||||
"admin": "role:admin or is_admin:True",
|
||||
"owner": "tenant_id:%(tenant_id)s or tenant_id:%(effective_tenant_id)s",
|
||||
"owner": "tenant_id:%(tenant_id)s",
|
||||
"admin_or_owner": "rule:admin or rule:owner",
|
||||
|
||||
"default": "rule:admin_or_owner",
|
||||
|
Loading…
x
Reference in New Issue
Block a user