add database string field length check
Added database string field length check, so when insert to a table, if the length of string field exceed the limit of column when, it will return a 400 error instead of truncating the string. Change-Id: I7216fe736ea6e5a23b5647b107fcb2699f1fa99d Fixes: bug #1090247
This commit is contained in:
parent
9460ff5c35
commit
9c2c4ece64
@ -23,10 +23,12 @@ from sqlalchemy.ext import declarative
|
|||||||
import sqlalchemy.orm
|
import sqlalchemy.orm
|
||||||
import sqlalchemy.pool
|
import sqlalchemy.pool
|
||||||
from sqlalchemy import types as sql_types
|
from sqlalchemy import types as sql_types
|
||||||
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
|
|
||||||
from keystone.common import logging
|
from keystone.common import logging
|
||||||
from keystone import config
|
from keystone import config
|
||||||
from keystone.openstack.common import jsonutils
|
from keystone.openstack.common import jsonutils
|
||||||
|
from keystone import exception
|
||||||
|
|
||||||
|
|
||||||
CONF = config.CONF
|
CONF = config.CONF
|
||||||
@ -49,6 +51,44 @@ Boolean = sql.Boolean
|
|||||||
Text = sql.Text
|
Text = sql.Text
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_decorator(init):
|
||||||
|
"""Ensure that the length of string field do not exceed the limit.
|
||||||
|
|
||||||
|
This decorator check the initialize arguments, to make sure the
|
||||||
|
length of string field do not exceed the length limit, or raise a
|
||||||
|
'StringLengthExceeded' exception.
|
||||||
|
|
||||||
|
Use decorator instead of inheritance, because the metaclass will
|
||||||
|
check the __tablename__, primary key columns, etc. at the class
|
||||||
|
definition.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def initialize(self, *args, **kwargs):
|
||||||
|
cls = type(self)
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if hasattr(cls, k):
|
||||||
|
attr = getattr(cls, k)
|
||||||
|
if isinstance(attr, InstrumentedAttribute):
|
||||||
|
column = attr.property.columns[0]
|
||||||
|
if isinstance(column.type, String):
|
||||||
|
if column.type.length and \
|
||||||
|
column.type.length < len(str(v)):
|
||||||
|
#if signing.token_format == 'PKI', the id will
|
||||||
|
#store it's public key which is very long.
|
||||||
|
if config.CONF.signing.token_format == 'PKI' and \
|
||||||
|
self.__tablename__ == 'token' and \
|
||||||
|
k == 'id':
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise exception.StringLengthExceeded(
|
||||||
|
string=v, type=k, length=column.type.length)
|
||||||
|
|
||||||
|
init(self, *args, **kwargs)
|
||||||
|
return initialize
|
||||||
|
|
||||||
|
ModelBase.__init__ = initialize_decorator(ModelBase.__init__)
|
||||||
|
|
||||||
|
|
||||||
def set_global_engine(engine):
|
def set_global_engine(engine):
|
||||||
global GLOBAL_ENGINE
|
global GLOBAL_ENGINE
|
||||||
GLOBAL_ENGINE = engine
|
GLOBAL_ENGINE = engine
|
||||||
|
@ -73,6 +73,11 @@ class ValidationError(Error):
|
|||||||
title = 'Bad Request'
|
title = 'Bad Request'
|
||||||
|
|
||||||
|
|
||||||
|
class StringLengthExceeded(ValidationError):
|
||||||
|
"""The length of string "%(string)s" exceeded the limit of column
|
||||||
|
%(type)s(CHAR(%(length)d))."""
|
||||||
|
|
||||||
|
|
||||||
class SecurityError(Error):
|
class SecurityError(Error):
|
||||||
"""Avoids exposing details of security failures, unless in debug mode."""
|
"""Avoids exposing details of security failures, unless in debug mode."""
|
||||||
|
|
||||||
|
@ -886,7 +886,7 @@ class CatalogTests(object):
|
|||||||
endpoint = {
|
endpoint = {
|
||||||
'id': uuid.uuid4().hex,
|
'id': uuid.uuid4().hex,
|
||||||
'region': uuid.uuid4().hex,
|
'region': uuid.uuid4().hex,
|
||||||
'interface': uuid.uuid4().hex,
|
'interface': uuid.uuid4().hex[:8],
|
||||||
'url': uuid.uuid4().hex,
|
'url': uuid.uuid4().hex,
|
||||||
'service_id': service['id'],
|
'service_id': service['id'],
|
||||||
}
|
}
|
||||||
@ -934,6 +934,24 @@ class CatalogTests(object):
|
|||||||
{},
|
{},
|
||||||
uuid.uuid4().hex)
|
uuid.uuid4().hex)
|
||||||
|
|
||||||
|
def test_create_endpoint(self):
|
||||||
|
service = {
|
||||||
|
'id': uuid.uuid4().hex,
|
||||||
|
'type': uuid.uuid4().hex,
|
||||||
|
'name': uuid.uuid4().hex,
|
||||||
|
'description': uuid.uuid4().hex,
|
||||||
|
}
|
||||||
|
self.catalog_api.create_service(service['id'], service.copy())
|
||||||
|
|
||||||
|
endpoint = {
|
||||||
|
'id': uuid.uuid4().hex,
|
||||||
|
'region': "0" * 255,
|
||||||
|
'service_id': service['id'],
|
||||||
|
'interface': 'public',
|
||||||
|
'url': uuid.uuid4().hex,
|
||||||
|
}
|
||||||
|
self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
|
||||||
|
|
||||||
|
|
||||||
class PolicyTests(object):
|
class PolicyTests(object):
|
||||||
def _new_policy_ref(self):
|
def _new_policy_ref(self):
|
||||||
|
@ -264,6 +264,26 @@ class SqlCatalog(SqlTests, test_backend.CatalogTests):
|
|||||||
self.assertIsNone(catalog_endpoint.get('adminURL'))
|
self.assertIsNone(catalog_endpoint.get('adminURL'))
|
||||||
self.assertIsNone(catalog_endpoint.get('internalURL'))
|
self.assertIsNone(catalog_endpoint.get('internalURL'))
|
||||||
|
|
||||||
|
def test_create_endpoint_400(self):
|
||||||
|
service = {
|
||||||
|
'id': uuid.uuid4().hex,
|
||||||
|
'type': uuid.uuid4().hex,
|
||||||
|
'name': uuid.uuid4().hex,
|
||||||
|
'description': uuid.uuid4().hex,
|
||||||
|
}
|
||||||
|
self.catalog_api.create_service(service['id'], service.copy())
|
||||||
|
|
||||||
|
endpoint = {
|
||||||
|
'id': uuid.uuid4().hex,
|
||||||
|
'region': "0" * 256,
|
||||||
|
'service_id': service['id'],
|
||||||
|
'interface': 'public',
|
||||||
|
'url': uuid.uuid4().hex,
|
||||||
|
}
|
||||||
|
|
||||||
|
with self.assertRaises(exception.StringLengthExceeded):
|
||||||
|
self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
|
||||||
|
|
||||||
|
|
||||||
class SqlPolicy(SqlTests, test_backend.PolicyTests):
|
class SqlPolicy(SqlTests, test_backend.PolicyTests):
|
||||||
pass
|
pass
|
||||||
|
@ -42,7 +42,7 @@ class RestfulTestCase(test_content_types.RestfulTestCase):
|
|||||||
|
|
||||||
def new_endpoint_ref(self, service_id):
|
def new_endpoint_ref(self, service_id):
|
||||||
ref = self.new_ref()
|
ref = self.new_ref()
|
||||||
ref['interface'] = uuid.uuid4().hex
|
ref['interface'] = uuid.uuid4().hex[:8]
|
||||||
ref['service_id'] = service_id
|
ref['service_id'] = service_id
|
||||||
ref['url'] = uuid.uuid4().hex
|
ref['url'] = uuid.uuid4().hex
|
||||||
return ref
|
return ref
|
||||||
|
@ -119,6 +119,15 @@ class CatalogTestCase(test_v3.RestfulTestCase):
|
|||||||
body={'endpoint': ref})
|
body={'endpoint': ref})
|
||||||
self.assertValidEndpointResponse(r, ref)
|
self.assertValidEndpointResponse(r, ref)
|
||||||
|
|
||||||
|
def assertValidErrorResponse(self, response):
|
||||||
|
self.assertTrue(response.status in [400])
|
||||||
|
|
||||||
|
def test_create_endpoint_400(self):
|
||||||
|
"""POST /endpoints"""
|
||||||
|
ref = self.new_endpoint_ref(service_id=self.service_id)
|
||||||
|
ref["region"] = "0" * 256
|
||||||
|
self.post('/endpoints', body={'endpoint': ref}, expected_status=400)
|
||||||
|
|
||||||
def test_get_endpoint(self):
|
def test_get_endpoint(self):
|
||||||
"""GET /endpoints/{endpoint_id}"""
|
"""GET /endpoints/{endpoint_id}"""
|
||||||
r = self.get(
|
r = self.get(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user