add database string field length check

Added database string field length check, so when insert to a table, if the length of string field exceed the limit of column when, it will return a 400 error instead of truncating the string.

Change-Id: I7216fe736ea6e5a23b5647b107fcb2699f1fa99d
Fixes: bug #1090247
This commit is contained in:
Tony NIU 2013-01-09 20:09:40 +08:00
parent 9460ff5c35
commit 9c2c4ece64
6 changed files with 94 additions and 2 deletions

View File

@ -23,10 +23,12 @@ from sqlalchemy.ext import declarative
import sqlalchemy.orm
import sqlalchemy.pool
from sqlalchemy import types as sql_types
from sqlalchemy.orm.attributes import InstrumentedAttribute
from keystone.common import logging
from keystone import config
from keystone.openstack.common import jsonutils
from keystone import exception
CONF = config.CONF
@ -49,6 +51,44 @@ Boolean = sql.Boolean
Text = sql.Text
def initialize_decorator(init):
"""Ensure that the length of string field do not exceed the limit.
This decorator check the initialize arguments, to make sure the
length of string field do not exceed the length limit, or raise a
'StringLengthExceeded' exception.
Use decorator instead of inheritance, because the metaclass will
check the __tablename__, primary key columns, etc. at the class
definition.
"""
def initialize(self, *args, **kwargs):
cls = type(self)
for k, v in kwargs.items():
if hasattr(cls, k):
attr = getattr(cls, k)
if isinstance(attr, InstrumentedAttribute):
column = attr.property.columns[0]
if isinstance(column.type, String):
if column.type.length and \
column.type.length < len(str(v)):
#if signing.token_format == 'PKI', the id will
#store it's public key which is very long.
if config.CONF.signing.token_format == 'PKI' and \
self.__tablename__ == 'token' and \
k == 'id':
continue
raise exception.StringLengthExceeded(
string=v, type=k, length=column.type.length)
init(self, *args, **kwargs)
return initialize
ModelBase.__init__ = initialize_decorator(ModelBase.__init__)
def set_global_engine(engine):
global GLOBAL_ENGINE
GLOBAL_ENGINE = engine

View File

@ -73,6 +73,11 @@ class ValidationError(Error):
title = 'Bad Request'
class StringLengthExceeded(ValidationError):
"""The length of string "%(string)s" exceeded the limit of column
%(type)s(CHAR(%(length)d))."""
class SecurityError(Error):
"""Avoids exposing details of security failures, unless in debug mode."""

View File

@ -886,7 +886,7 @@ class CatalogTests(object):
endpoint = {
'id': uuid.uuid4().hex,
'region': uuid.uuid4().hex,
'interface': uuid.uuid4().hex,
'interface': uuid.uuid4().hex[:8],
'url': uuid.uuid4().hex,
'service_id': service['id'],
}
@ -934,6 +934,24 @@ class CatalogTests(object):
{},
uuid.uuid4().hex)
def test_create_endpoint(self):
service = {
'id': uuid.uuid4().hex,
'type': uuid.uuid4().hex,
'name': uuid.uuid4().hex,
'description': uuid.uuid4().hex,
}
self.catalog_api.create_service(service['id'], service.copy())
endpoint = {
'id': uuid.uuid4().hex,
'region': "0" * 255,
'service_id': service['id'],
'interface': 'public',
'url': uuid.uuid4().hex,
}
self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
class PolicyTests(object):
def _new_policy_ref(self):

View File

@ -264,6 +264,26 @@ class SqlCatalog(SqlTests, test_backend.CatalogTests):
self.assertIsNone(catalog_endpoint.get('adminURL'))
self.assertIsNone(catalog_endpoint.get('internalURL'))
def test_create_endpoint_400(self):
service = {
'id': uuid.uuid4().hex,
'type': uuid.uuid4().hex,
'name': uuid.uuid4().hex,
'description': uuid.uuid4().hex,
}
self.catalog_api.create_service(service['id'], service.copy())
endpoint = {
'id': uuid.uuid4().hex,
'region': "0" * 256,
'service_id': service['id'],
'interface': 'public',
'url': uuid.uuid4().hex,
}
with self.assertRaises(exception.StringLengthExceeded):
self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
class SqlPolicy(SqlTests, test_backend.PolicyTests):
pass

View File

@ -42,7 +42,7 @@ class RestfulTestCase(test_content_types.RestfulTestCase):
def new_endpoint_ref(self, service_id):
ref = self.new_ref()
ref['interface'] = uuid.uuid4().hex
ref['interface'] = uuid.uuid4().hex[:8]
ref['service_id'] = service_id
ref['url'] = uuid.uuid4().hex
return ref

View File

@ -119,6 +119,15 @@ class CatalogTestCase(test_v3.RestfulTestCase):
body={'endpoint': ref})
self.assertValidEndpointResponse(r, ref)
def assertValidErrorResponse(self, response):
self.assertTrue(response.status in [400])
def test_create_endpoint_400(self):
"""POST /endpoints"""
ref = self.new_endpoint_ref(service_id=self.service_id)
ref["region"] = "0" * 256
self.post('/endpoints', body={'endpoint': ref}, expected_status=400)
def test_get_endpoint(self):
"""GET /endpoints/{endpoint_id}"""
r = self.get(