diff --git a/keystone/common/sql/core.py b/keystone/common/sql/core.py index 10629fc340..3634c75a3e 100644 --- a/keystone/common/sql/core.py +++ b/keystone/common/sql/core.py @@ -23,10 +23,12 @@ from sqlalchemy.ext import declarative import sqlalchemy.orm import sqlalchemy.pool from sqlalchemy import types as sql_types +from sqlalchemy.orm.attributes import InstrumentedAttribute from keystone.common import logging from keystone import config from keystone.openstack.common import jsonutils +from keystone import exception CONF = config.CONF @@ -49,6 +51,44 @@ Boolean = sql.Boolean Text = sql.Text +def initialize_decorator(init): + """Ensure that the length of string field do not exceed the limit. + + This decorator check the initialize arguments, to make sure the + length of string field do not exceed the length limit, or raise a + 'StringLengthExceeded' exception. + + Use decorator instead of inheritance, because the metaclass will + check the __tablename__, primary key columns, etc. at the class + definition. + + """ + def initialize(self, *args, **kwargs): + cls = type(self) + for k, v in kwargs.items(): + if hasattr(cls, k): + attr = getattr(cls, k) + if isinstance(attr, InstrumentedAttribute): + column = attr.property.columns[0] + if isinstance(column.type, String): + if column.type.length and \ + column.type.length < len(str(v)): + #if signing.token_format == 'PKI', the id will + #store it's public key which is very long. + if config.CONF.signing.token_format == 'PKI' and \ + self.__tablename__ == 'token' and \ + k == 'id': + continue + + raise exception.StringLengthExceeded( + string=v, type=k, length=column.type.length) + + init(self, *args, **kwargs) + return initialize + +ModelBase.__init__ = initialize_decorator(ModelBase.__init__) + + def set_global_engine(engine): global GLOBAL_ENGINE GLOBAL_ENGINE = engine diff --git a/keystone/exception.py b/keystone/exception.py index 5a545fc359..26697e0db4 100644 --- a/keystone/exception.py +++ b/keystone/exception.py @@ -79,6 +79,11 @@ class ValidationError(Error): title = 'Bad Request' +class StringLengthExceeded(ValidationError): + """The length of string "%(string)s" exceeded the limit of column + %(type)s(CHAR(%(length)d)).""" + + class SecurityError(Error): """Avoids exposing details of security failures, unless in debug mode.""" diff --git a/tests/test_backend.py b/tests/test_backend.py index 57e41df45b..80b709d731 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1192,7 +1192,7 @@ class CatalogTests(object): endpoint = { 'id': uuid.uuid4().hex, 'region': uuid.uuid4().hex, - 'interface': uuid.uuid4().hex, + 'interface': uuid.uuid4().hex[:8], 'url': uuid.uuid4().hex, 'service_id': service['id'], } @@ -1240,6 +1240,24 @@ class CatalogTests(object): {}, uuid.uuid4().hex) + def test_create_endpoint(self): + service = { + 'id': uuid.uuid4().hex, + 'type': uuid.uuid4().hex, + 'name': uuid.uuid4().hex, + 'description': uuid.uuid4().hex, + } + self.catalog_api.create_service(service['id'], service.copy()) + + endpoint = { + 'id': uuid.uuid4().hex, + 'region': "0" * 255, + 'service_id': service['id'], + 'interface': 'public', + 'url': uuid.uuid4().hex, + } + self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy()) + class PolicyTests(object): def _new_policy_ref(self): diff --git a/tests/test_backend_sql.py b/tests/test_backend_sql.py index 925ea9126c..cff7788c5a 100644 --- a/tests/test_backend_sql.py +++ b/tests/test_backend_sql.py @@ -264,6 +264,26 @@ class SqlCatalog(SqlTests, test_backend.CatalogTests): self.assertIsNone(catalog_endpoint.get('adminURL')) self.assertIsNone(catalog_endpoint.get('internalURL')) + def test_create_endpoint_400(self): + service = { + 'id': uuid.uuid4().hex, + 'type': uuid.uuid4().hex, + 'name': uuid.uuid4().hex, + 'description': uuid.uuid4().hex, + } + self.catalog_api.create_service(service['id'], service.copy()) + + endpoint = { + 'id': uuid.uuid4().hex, + 'region': "0" * 256, + 'service_id': service['id'], + 'interface': 'public', + 'url': uuid.uuid4().hex, + } + + with self.assertRaises(exception.StringLengthExceeded): + self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy()) + class SqlPolicy(SqlTests, test_backend.PolicyTests): pass diff --git a/tests/test_v3.py b/tests/test_v3.py index 254f14f2b7..4c7615d7df 100644 --- a/tests/test_v3.py +++ b/tests/test_v3.py @@ -42,7 +42,7 @@ class RestfulTestCase(test_content_types.RestfulTestCase): def new_endpoint_ref(self, service_id): ref = self.new_ref() - ref['interface'] = uuid.uuid4().hex + ref['interface'] = uuid.uuid4().hex[:8] ref['service_id'] = service_id ref['url'] = uuid.uuid4().hex return ref diff --git a/tests/test_v3_catalog.py b/tests/test_v3_catalog.py index 9f5bf9139a..3a90170944 100644 --- a/tests/test_v3_catalog.py +++ b/tests/test_v3_catalog.py @@ -119,6 +119,15 @@ class CatalogTestCase(test_v3.RestfulTestCase): body={'endpoint': ref}) self.assertValidEndpointResponse(r, ref) + def assertValidErrorResponse(self, response): + self.assertTrue(response.status in [400]) + + def test_create_endpoint_400(self): + """POST /endpoints""" + ref = self.new_endpoint_ref(service_id=self.service_id) + ref["region"] = "0" * 256 + self.post('/endpoints', body={'endpoint': ref}, expected_status=400) + def test_get_endpoint(self): """GET /endpoints/{endpoint_id}""" r = self.get(