Merge "add database string field length check"
This commit is contained in:
commit
e1abe0fca3
@ -23,10 +23,12 @@ from sqlalchemy.ext import declarative
|
||||
import sqlalchemy.orm
|
||||
import sqlalchemy.pool
|
||||
from sqlalchemy import types as sql_types
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
|
||||
from keystone.common import logging
|
||||
from keystone import config
|
||||
from keystone.openstack.common import jsonutils
|
||||
from keystone import exception
|
||||
|
||||
|
||||
CONF = config.CONF
|
||||
@ -49,6 +51,44 @@ Boolean = sql.Boolean
|
||||
Text = sql.Text
|
||||
|
||||
|
||||
def initialize_decorator(init):
|
||||
"""Ensure that the length of string field do not exceed the limit.
|
||||
|
||||
This decorator check the initialize arguments, to make sure the
|
||||
length of string field do not exceed the length limit, or raise a
|
||||
'StringLengthExceeded' exception.
|
||||
|
||||
Use decorator instead of inheritance, because the metaclass will
|
||||
check the __tablename__, primary key columns, etc. at the class
|
||||
definition.
|
||||
|
||||
"""
|
||||
def initialize(self, *args, **kwargs):
|
||||
cls = type(self)
|
||||
for k, v in kwargs.items():
|
||||
if hasattr(cls, k):
|
||||
attr = getattr(cls, k)
|
||||
if isinstance(attr, InstrumentedAttribute):
|
||||
column = attr.property.columns[0]
|
||||
if isinstance(column.type, String):
|
||||
if column.type.length and \
|
||||
column.type.length < len(str(v)):
|
||||
#if signing.token_format == 'PKI', the id will
|
||||
#store it's public key which is very long.
|
||||
if config.CONF.signing.token_format == 'PKI' and \
|
||||
self.__tablename__ == 'token' and \
|
||||
k == 'id':
|
||||
continue
|
||||
|
||||
raise exception.StringLengthExceeded(
|
||||
string=v, type=k, length=column.type.length)
|
||||
|
||||
init(self, *args, **kwargs)
|
||||
return initialize
|
||||
|
||||
ModelBase.__init__ = initialize_decorator(ModelBase.__init__)
|
||||
|
||||
|
||||
def set_global_engine(engine):
|
||||
global GLOBAL_ENGINE
|
||||
GLOBAL_ENGINE = engine
|
||||
|
@ -79,6 +79,11 @@ class ValidationError(Error):
|
||||
title = 'Bad Request'
|
||||
|
||||
|
||||
class StringLengthExceeded(ValidationError):
|
||||
"""The length of string "%(string)s" exceeded the limit of column
|
||||
%(type)s(CHAR(%(length)d))."""
|
||||
|
||||
|
||||
class SecurityError(Error):
|
||||
"""Avoids exposing details of security failures, unless in debug mode."""
|
||||
|
||||
|
@ -1192,7 +1192,7 @@ class CatalogTests(object):
|
||||
endpoint = {
|
||||
'id': uuid.uuid4().hex,
|
||||
'region': uuid.uuid4().hex,
|
||||
'interface': uuid.uuid4().hex,
|
||||
'interface': uuid.uuid4().hex[:8],
|
||||
'url': uuid.uuid4().hex,
|
||||
'service_id': service['id'],
|
||||
}
|
||||
@ -1240,6 +1240,24 @@ class CatalogTests(object):
|
||||
{},
|
||||
uuid.uuid4().hex)
|
||||
|
||||
def test_create_endpoint(self):
|
||||
service = {
|
||||
'id': uuid.uuid4().hex,
|
||||
'type': uuid.uuid4().hex,
|
||||
'name': uuid.uuid4().hex,
|
||||
'description': uuid.uuid4().hex,
|
||||
}
|
||||
self.catalog_api.create_service(service['id'], service.copy())
|
||||
|
||||
endpoint = {
|
||||
'id': uuid.uuid4().hex,
|
||||
'region': "0" * 255,
|
||||
'service_id': service['id'],
|
||||
'interface': 'public',
|
||||
'url': uuid.uuid4().hex,
|
||||
}
|
||||
self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
|
||||
|
||||
|
||||
class PolicyTests(object):
|
||||
def _new_policy_ref(self):
|
||||
|
@ -264,6 +264,26 @@ class SqlCatalog(SqlTests, test_backend.CatalogTests):
|
||||
self.assertIsNone(catalog_endpoint.get('adminURL'))
|
||||
self.assertIsNone(catalog_endpoint.get('internalURL'))
|
||||
|
||||
def test_create_endpoint_400(self):
|
||||
service = {
|
||||
'id': uuid.uuid4().hex,
|
||||
'type': uuid.uuid4().hex,
|
||||
'name': uuid.uuid4().hex,
|
||||
'description': uuid.uuid4().hex,
|
||||
}
|
||||
self.catalog_api.create_service(service['id'], service.copy())
|
||||
|
||||
endpoint = {
|
||||
'id': uuid.uuid4().hex,
|
||||
'region': "0" * 256,
|
||||
'service_id': service['id'],
|
||||
'interface': 'public',
|
||||
'url': uuid.uuid4().hex,
|
||||
}
|
||||
|
||||
with self.assertRaises(exception.StringLengthExceeded):
|
||||
self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
|
||||
|
||||
|
||||
class SqlPolicy(SqlTests, test_backend.PolicyTests):
|
||||
pass
|
||||
|
@ -42,7 +42,7 @@ class RestfulTestCase(test_content_types.RestfulTestCase):
|
||||
|
||||
def new_endpoint_ref(self, service_id):
|
||||
ref = self.new_ref()
|
||||
ref['interface'] = uuid.uuid4().hex
|
||||
ref['interface'] = uuid.uuid4().hex[:8]
|
||||
ref['service_id'] = service_id
|
||||
ref['url'] = uuid.uuid4().hex
|
||||
return ref
|
||||
|
@ -119,6 +119,15 @@ class CatalogTestCase(test_v3.RestfulTestCase):
|
||||
body={'endpoint': ref})
|
||||
self.assertValidEndpointResponse(r, ref)
|
||||
|
||||
def assertValidErrorResponse(self, response):
|
||||
self.assertTrue(response.status in [400])
|
||||
|
||||
def test_create_endpoint_400(self):
|
||||
"""POST /endpoints"""
|
||||
ref = self.new_endpoint_ref(service_id=self.service_id)
|
||||
ref["region"] = "0" * 256
|
||||
self.post('/endpoints', body={'endpoint': ref}, expected_status=400)
|
||||
|
||||
def test_get_endpoint(self):
|
||||
"""GET /endpoints/{endpoint_id}"""
|
||||
r = self.get(
|
||||
|
Loading…
Reference in New Issue
Block a user