Merge "add database string field length check"

This commit is contained in:
Jenkins 2013-01-15 02:33:15 +00:00 committed by Gerrit Code Review
commit e1abe0fca3
6 changed files with 94 additions and 2 deletions

View File

@ -23,10 +23,12 @@ from sqlalchemy.ext import declarative
import sqlalchemy.orm
import sqlalchemy.pool
from sqlalchemy import types as sql_types
from sqlalchemy.orm.attributes import InstrumentedAttribute
from keystone.common import logging
from keystone import config
from keystone.openstack.common import jsonutils
from keystone import exception
CONF = config.CONF
@ -49,6 +51,44 @@ Boolean = sql.Boolean
Text = sql.Text
def initialize_decorator(init):
"""Ensure that the length of string field do not exceed the limit.
This decorator check the initialize arguments, to make sure the
length of string field do not exceed the length limit, or raise a
'StringLengthExceeded' exception.
Use decorator instead of inheritance, because the metaclass will
check the __tablename__, primary key columns, etc. at the class
definition.
"""
def initialize(self, *args, **kwargs):
cls = type(self)
for k, v in kwargs.items():
if hasattr(cls, k):
attr = getattr(cls, k)
if isinstance(attr, InstrumentedAttribute):
column = attr.property.columns[0]
if isinstance(column.type, String):
if column.type.length and \
column.type.length < len(str(v)):
#if signing.token_format == 'PKI', the id will
#store it's public key which is very long.
if config.CONF.signing.token_format == 'PKI' and \
self.__tablename__ == 'token' and \
k == 'id':
continue
raise exception.StringLengthExceeded(
string=v, type=k, length=column.type.length)
init(self, *args, **kwargs)
return initialize
ModelBase.__init__ = initialize_decorator(ModelBase.__init__)
def set_global_engine(engine):
global GLOBAL_ENGINE
GLOBAL_ENGINE = engine

View File

@ -79,6 +79,11 @@ class ValidationError(Error):
title = 'Bad Request'
class StringLengthExceeded(ValidationError):
"""The length of string "%(string)s" exceeded the limit of column
%(type)s(CHAR(%(length)d))."""
class SecurityError(Error):
"""Avoids exposing details of security failures, unless in debug mode."""

View File

@ -1192,7 +1192,7 @@ class CatalogTests(object):
endpoint = {
'id': uuid.uuid4().hex,
'region': uuid.uuid4().hex,
'interface': uuid.uuid4().hex,
'interface': uuid.uuid4().hex[:8],
'url': uuid.uuid4().hex,
'service_id': service['id'],
}
@ -1240,6 +1240,24 @@ class CatalogTests(object):
{},
uuid.uuid4().hex)
def test_create_endpoint(self):
service = {
'id': uuid.uuid4().hex,
'type': uuid.uuid4().hex,
'name': uuid.uuid4().hex,
'description': uuid.uuid4().hex,
}
self.catalog_api.create_service(service['id'], service.copy())
endpoint = {
'id': uuid.uuid4().hex,
'region': "0" * 255,
'service_id': service['id'],
'interface': 'public',
'url': uuid.uuid4().hex,
}
self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
class PolicyTests(object):
def _new_policy_ref(self):

View File

@ -264,6 +264,26 @@ class SqlCatalog(SqlTests, test_backend.CatalogTests):
self.assertIsNone(catalog_endpoint.get('adminURL'))
self.assertIsNone(catalog_endpoint.get('internalURL'))
def test_create_endpoint_400(self):
service = {
'id': uuid.uuid4().hex,
'type': uuid.uuid4().hex,
'name': uuid.uuid4().hex,
'description': uuid.uuid4().hex,
}
self.catalog_api.create_service(service['id'], service.copy())
endpoint = {
'id': uuid.uuid4().hex,
'region': "0" * 256,
'service_id': service['id'],
'interface': 'public',
'url': uuid.uuid4().hex,
}
with self.assertRaises(exception.StringLengthExceeded):
self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
class SqlPolicy(SqlTests, test_backend.PolicyTests):
pass

View File

@ -42,7 +42,7 @@ class RestfulTestCase(test_content_types.RestfulTestCase):
def new_endpoint_ref(self, service_id):
ref = self.new_ref()
ref['interface'] = uuid.uuid4().hex
ref['interface'] = uuid.uuid4().hex[:8]
ref['service_id'] = service_id
ref['url'] = uuid.uuid4().hex
return ref

View File

@ -119,6 +119,15 @@ class CatalogTestCase(test_v3.RestfulTestCase):
body={'endpoint': ref})
self.assertValidEndpointResponse(r, ref)
def assertValidErrorResponse(self, response):
self.assertTrue(response.status in [400])
def test_create_endpoint_400(self):
"""POST /endpoints"""
ref = self.new_endpoint_ref(service_id=self.service_id)
ref["region"] = "0" * 256
self.post('/endpoints', body={'endpoint': ref}, expected_status=400)
def test_get_endpoint(self):
"""GET /endpoints/{endpoint_id}"""
r = self.get(