Merge "Moved name formatting (clean) out of the driver"

This commit is contained in:
Jenkins 2016-04-18 18:44:41 +00:00 committed by Gerrit Code Review
commit be3d75f67d
6 changed files with 45 additions and 12 deletions

View File

@ -20,7 +20,6 @@ from oslo_log import log
from oslo_log import versionutils from oslo_log import versionutils
import six import six
from keystone.common import clean
from keystone.common import driver_hints from keystone.common import driver_hints
from keystone.common import ldap as common_ldap from keystone.common import ldap as common_ldap
from keystone.common.ldap import models from keystone.common.ldap import models
@ -135,7 +134,6 @@ class Identity(base.IdentityDriverV8):
msg = _DEPRECATION_MSG % "create_group" msg = _DEPRECATION_MSG % "create_group"
versionutils.report_deprecated_feature(LOG, msg) versionutils.report_deprecated_feature(LOG, msg)
self.group.check_allow_create() self.group.check_allow_create()
group['name'] = clean.group_name(group['name'])
return common_ldap.filter_entity(self.group.create(group)) return common_ldap.filter_entity(self.group.create(group))
def get_group(self, group_id): def get_group(self, group_id):
@ -150,8 +148,6 @@ class Identity(base.IdentityDriverV8):
msg = _DEPRECATION_MSG % "update_group" msg = _DEPRECATION_MSG % "update_group"
versionutils.report_deprecated_feature(LOG, msg) versionutils.report_deprecated_feature(LOG, msg)
self.group.check_allow_update() self.group.check_allow_update()
if 'name' in group:
group['name'] = clean.group_name(group['name'])
return common_ldap.filter_entity(self.group.update(group_id, group)) return common_ldap.filter_entity(self.group.update(group_id, group))
def delete_group(self, group_id): def delete_group(self, group_id):

View File

@ -983,6 +983,7 @@ class Manager(manager.Manager):
# the underlying driver so that it could conform to rules set down by # the underlying driver so that it could conform to rules set down by
# that particular driver type. # that particular driver type.
group['id'] = uuid.uuid4().hex group['id'] = uuid.uuid4().hex
group['name'] = clean.group_name(group['name'])
ref = driver.create_group(group['id'], group) ref = driver.create_group(group['id'], group)
notifications.Audit.created(self._GROUP, group['id'], initiator) notifications.Audit.created(self._GROUP, group['id'], initiator)
@ -1019,6 +1020,8 @@ class Manager(manager.Manager):
domain_id, driver, entity_id = ( domain_id, driver, entity_id = (
self._get_domain_driver_and_entity_id(group_id)) self._get_domain_driver_and_entity_id(group_id))
group = self._clear_domain_id_if_domain_unaware(driver, group) group = self._clear_domain_id_if_domain_unaware(driver, group)
if 'name' in group:
group['name'] = clean.group_name(group['name'])
ref = driver.update_group(entity_id, group) ref = driver.update_group(entity_id, group)
self.get_group.invalidate(self, group_id) self.get_group.invalidate(self, group_id)
notifications.Audit.updated(self._GROUP, group_id, initiator) notifications.Audit.updated(self._GROUP, group_id, initiator)

View File

@ -12,7 +12,6 @@
from oslo_log import log from oslo_log import log
from keystone.common import clean
from keystone.common import driver_hints from keystone.common import driver_hints
from keystone.common import sql from keystone.common import sql
from keystone import exception from keystone import exception
@ -174,7 +173,6 @@ class Resource(keystone_resource.ResourceDriverV9):
# CRUD # CRUD
@sql.handle_conflicts(conflict_type='project') @sql.handle_conflicts(conflict_type='project')
def create_project(self, project_id, project): def create_project(self, project_id, project):
project['name'] = clean.project_name(project['name'])
new_project = self._encode_domain_id(project) new_project = self._encode_domain_id(project)
with sql.session_for_write() as session: with sql.session_for_write() as session:
project_ref = Project.from_dict(new_project) project_ref = Project.from_dict(new_project)
@ -183,9 +181,6 @@ class Resource(keystone_resource.ResourceDriverV9):
@sql.handle_conflicts(conflict_type='project') @sql.handle_conflicts(conflict_type='project')
def update_project(self, project_id, project): def update_project(self, project_id, project):
if 'name' in project:
project['name'] = clean.project_name(project['name'])
update_project = self._encode_domain_id(project) update_project = self._encode_domain_id(project)
with sql.session_for_write() as session: with sql.session_for_write() as session:
project_ref = self._get_project(session, project_id) project_ref = self._get_project(session, project_id)

View File

@ -216,6 +216,7 @@ class Manager(manager.Manager):
project.setdefault('enabled', True) project.setdefault('enabled', True)
project['enabled'] = clean.project_enabled(project['enabled']) project['enabled'] = clean.project_enabled(project['enabled'])
project['name'] = clean.project_name(project['name'])
project.setdefault('description', '') project.setdefault('description', '')
# For regular projects, the controller will ensure we have a valid # For regular projects, the controller will ensure we have a valid
@ -327,12 +328,14 @@ class Manager(manager.Manager):
url_safe_option = CONF.resource.project_name_url_safe url_safe_option = CONF.resource.project_name_url_safe
exception_entity = 'Project' exception_entity = 'Project'
if (url_safe_option != 'off' and project_name_changed = ('name' in project and project['name'] !=
'name' in project and original_project['name'])
project['name'] != original_project['name'] and if (url_safe_option != 'off' and project_name_changed and
utils.is_not_url_safe(project['name'])): utils.is_not_url_safe(project['name'])):
self._raise_reserved_character_exception(exception_entity, self._raise_reserved_character_exception(exception_entity,
project['name']) project['name'])
elif project_name_changed:
project['name'] = clean.project_name(project['name'])
parent_id = original_project.get('parent_id') parent_id = original_project.get('parent_id')
if 'parent_id' in project and project.get('parent_id') != parent_id: if 'parent_id' in project and project.get('parent_id') != parent_id:

View File

@ -804,6 +804,21 @@ class IdentityTests(object):
self.identity_api.get_group, self.identity_api.get_group,
group['id']) group['id'])
def test_create_group_name_with_trailing_whitespace(self):
group = unit.new_group_ref(domain_id=CONF.identity.default_domain_id)
group_name = group['name'] = (group['name'] + ' ')
group_returned = self.identity_api.create_group(group)
self.assertEqual(group_returned['name'], group_name.strip())
def test_update_group_name_with_trailing_whitespace(self):
group = unit.new_group_ref(domain_id=CONF.identity.default_domain_id)
group_create = self.identity_api.create_group(group)
group_name = group['name'] = (group['name'] + ' ')
group_update = self.identity_api.update_group(group_create['id'],
group)
self.assertEqual(group_update['id'], group_create['id'])
self.assertEqual(group_update['name'], group_name.strip())
def test_get_group_by_name(self): def test_get_group_by_name(self):
group = unit.new_group_ref(domain_id=CONF.identity.default_domain_id) group = unit.new_group_ref(domain_id=CONF.identity.default_domain_id)
group_name = group['name'] group_name = group['name']

View File

@ -99,6 +99,16 @@ class ResourceTests(object):
project['id'], project['id'],
project) project)
def test_create_project_name_with_trailing_whitespace(self):
project = unit.new_project_ref(
domain_id=CONF.identity.default_domain_id)
project_id = project['id']
project_name = project['name'] = (project['name'] + ' ')
project_returned = self.resource_api.create_project(project_id,
project)
self.assertEqual(project_returned['id'], project_id)
self.assertEqual(project_returned['name'], project_name.strip())
def test_create_duplicate_project_name_in_different_domains(self): def test_create_duplicate_project_name_in_different_domains(self):
new_domain = unit.new_domain_ref() new_domain = unit.new_domain_ref()
self.resource_api.create_domain(new_domain['id'], new_domain) self.resource_api.create_domain(new_domain['id'], new_domain)
@ -241,6 +251,17 @@ class ResourceTests(object):
self.resource_api.get_project, self.resource_api.get_project,
'fake2') 'fake2')
def test_update_project_name_with_trailing_whitespace(self):
project = unit.new_project_ref(
domain_id=CONF.identity.default_domain_id)
project_id = project['id']
project_create = self.resource_api.create_project(project_id, project)
self.assertEqual(project_create['id'], project_id)
project_name = project['name'] = (project['name'] + ' ')
project_update = self.resource_api.update_project(project_id, project)
self.assertEqual(project_update['id'], project_id)
self.assertEqual(project_update['name'], project_name.strip())
def test_delete_domain_with_user_group_project_links(self): def test_delete_domain_with_user_group_project_links(self):
# TODO(chungg):add test case once expected behaviour defined # TODO(chungg):add test case once expected behaviour defined
pass pass