add database string field length check

Added database string field length check, so when insert to a table, if the length of string field exceed the limit of column when, it will return a 400 error instead of truncating the string.

Change-Id: I7216fe736ea6e5a23b5647b107fcb2699f1fa99d
Fixes: bug #1090247
This commit is contained in:
Tony NIU 2013-01-09 20:09:40 +08:00
parent 9460ff5c35
commit 9c2c4ece64
6 changed files with 94 additions and 2 deletions

View File

@ -23,10 +23,12 @@ from sqlalchemy.ext import declarative
import sqlalchemy.orm import sqlalchemy.orm
import sqlalchemy.pool import sqlalchemy.pool
from sqlalchemy import types as sql_types from sqlalchemy import types as sql_types
from sqlalchemy.orm.attributes import InstrumentedAttribute
from keystone.common import logging from keystone.common import logging
from keystone import config from keystone import config
from keystone.openstack.common import jsonutils from keystone.openstack.common import jsonutils
from keystone import exception
CONF = config.CONF CONF = config.CONF
@ -49,6 +51,44 @@ Boolean = sql.Boolean
Text = sql.Text Text = sql.Text
def initialize_decorator(init):
"""Ensure that the length of string field do not exceed the limit.
This decorator check the initialize arguments, to make sure the
length of string field do not exceed the length limit, or raise a
'StringLengthExceeded' exception.
Use decorator instead of inheritance, because the metaclass will
check the __tablename__, primary key columns, etc. at the class
definition.
"""
def initialize(self, *args, **kwargs):
cls = type(self)
for k, v in kwargs.items():
if hasattr(cls, k):
attr = getattr(cls, k)
if isinstance(attr, InstrumentedAttribute):
column = attr.property.columns[0]
if isinstance(column.type, String):
if column.type.length and \
column.type.length < len(str(v)):
#if signing.token_format == 'PKI', the id will
#store it's public key which is very long.
if config.CONF.signing.token_format == 'PKI' and \
self.__tablename__ == 'token' and \
k == 'id':
continue
raise exception.StringLengthExceeded(
string=v, type=k, length=column.type.length)
init(self, *args, **kwargs)
return initialize
ModelBase.__init__ = initialize_decorator(ModelBase.__init__)
def set_global_engine(engine): def set_global_engine(engine):
global GLOBAL_ENGINE global GLOBAL_ENGINE
GLOBAL_ENGINE = engine GLOBAL_ENGINE = engine

View File

@ -73,6 +73,11 @@ class ValidationError(Error):
title = 'Bad Request' title = 'Bad Request'
class StringLengthExceeded(ValidationError):
"""The length of string "%(string)s" exceeded the limit of column
%(type)s(CHAR(%(length)d))."""
class SecurityError(Error): class SecurityError(Error):
"""Avoids exposing details of security failures, unless in debug mode.""" """Avoids exposing details of security failures, unless in debug mode."""

View File

@ -886,7 +886,7 @@ class CatalogTests(object):
endpoint = { endpoint = {
'id': uuid.uuid4().hex, 'id': uuid.uuid4().hex,
'region': uuid.uuid4().hex, 'region': uuid.uuid4().hex,
'interface': uuid.uuid4().hex, 'interface': uuid.uuid4().hex[:8],
'url': uuid.uuid4().hex, 'url': uuid.uuid4().hex,
'service_id': service['id'], 'service_id': service['id'],
} }
@ -934,6 +934,24 @@ class CatalogTests(object):
{}, {},
uuid.uuid4().hex) uuid.uuid4().hex)
def test_create_endpoint(self):
service = {
'id': uuid.uuid4().hex,
'type': uuid.uuid4().hex,
'name': uuid.uuid4().hex,
'description': uuid.uuid4().hex,
}
self.catalog_api.create_service(service['id'], service.copy())
endpoint = {
'id': uuid.uuid4().hex,
'region': "0" * 255,
'service_id': service['id'],
'interface': 'public',
'url': uuid.uuid4().hex,
}
self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
class PolicyTests(object): class PolicyTests(object):
def _new_policy_ref(self): def _new_policy_ref(self):

View File

@ -264,6 +264,26 @@ class SqlCatalog(SqlTests, test_backend.CatalogTests):
self.assertIsNone(catalog_endpoint.get('adminURL')) self.assertIsNone(catalog_endpoint.get('adminURL'))
self.assertIsNone(catalog_endpoint.get('internalURL')) self.assertIsNone(catalog_endpoint.get('internalURL'))
def test_create_endpoint_400(self):
service = {
'id': uuid.uuid4().hex,
'type': uuid.uuid4().hex,
'name': uuid.uuid4().hex,
'description': uuid.uuid4().hex,
}
self.catalog_api.create_service(service['id'], service.copy())
endpoint = {
'id': uuid.uuid4().hex,
'region': "0" * 256,
'service_id': service['id'],
'interface': 'public',
'url': uuid.uuid4().hex,
}
with self.assertRaises(exception.StringLengthExceeded):
self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
class SqlPolicy(SqlTests, test_backend.PolicyTests): class SqlPolicy(SqlTests, test_backend.PolicyTests):
pass pass

View File

@ -42,7 +42,7 @@ class RestfulTestCase(test_content_types.RestfulTestCase):
def new_endpoint_ref(self, service_id): def new_endpoint_ref(self, service_id):
ref = self.new_ref() ref = self.new_ref()
ref['interface'] = uuid.uuid4().hex ref['interface'] = uuid.uuid4().hex[:8]
ref['service_id'] = service_id ref['service_id'] = service_id
ref['url'] = uuid.uuid4().hex ref['url'] = uuid.uuid4().hex
return ref return ref

View File

@ -119,6 +119,15 @@ class CatalogTestCase(test_v3.RestfulTestCase):
body={'endpoint': ref}) body={'endpoint': ref})
self.assertValidEndpointResponse(r, ref) self.assertValidEndpointResponse(r, ref)
def assertValidErrorResponse(self, response):
self.assertTrue(response.status in [400])
def test_create_endpoint_400(self):
"""POST /endpoints"""
ref = self.new_endpoint_ref(service_id=self.service_id)
ref["region"] = "0" * 256
self.post('/endpoints', body={'endpoint': ref}, expected_status=400)
def test_get_endpoint(self): def test_get_endpoint(self):
"""GET /endpoints/{endpoint_id}""" """GET /endpoints/{endpoint_id}"""
r = self.get( r = self.get(