diff --git a/swift3/exception.py b/swift3/exception.py new file mode 100644 index 00000000..2d0730c7 --- /dev/null +++ b/swift3/exception.py @@ -0,0 +1,22 @@ +# Copyright (c) 2014 OpenStack Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class S3Exception(Exception): + pass + + +class NotS3Request(S3Exception): + pass diff --git a/swift3/middleware.py b/swift3/middleware.py index 450e691f..7650e187 100644 --- a/swift3/middleware.py +++ b/swift3/middleware.py @@ -53,14 +53,10 @@ following for an SAIO setup:: """ from urllib import quote -import base64 from simplejson import loads -import email.utils -import datetime import re from swift.common.utils import get_logger -from swift.common.swob import Request from swift.common.http import HTTP_OK, HTTP_CREATED, HTTP_ACCEPTED, \ HTTP_NO_CONTENT, HTTP_UNAUTHORIZED, HTTP_FORBIDDEN, HTTP_NOT_FOUND, \ HTTP_CONFLICT, HTTP_UNPROCESSABLE_ENTITY, is_success, \ @@ -68,11 +64,12 @@ from swift.common.http import HTTP_OK, HTTP_CREATED, HTTP_ACCEPTED, \ from swift.common.middleware.acl import parse_acl, referrer_allowed from swift3.etree import fromstring, tostring, Element, SubElement -from swift3.response import Response, HTTPNoContent, HTTPOk, ErrorResponse, \ +from swift3.exception import NotS3Request +from swift3.request import Request +from swift3.response import HTTPNoContent, HTTPOk, ErrorResponse, \ AccessDenied, BucketAlreadyExists, BucketNotEmpty, EntityTooLarge, \ - InternalError, InvalidArgument, InvalidDigest, InvalidURI, \ - MalformedACLError, MethodNotAllowed, NoSuchBucket, NoSuchKey, \ - S3NotImplemented, RequestTimeTooSkewed, SignatureDoesNotMatch + InternalError, InvalidArgument, InvalidDigest, MalformedACLError, \ + MethodNotAllowed, NoSuchBucket, NoSuchKey, S3NotImplemented XMLNS_XSI = 'http://www.w3.org/2001/XMLSchema-instance' @@ -140,45 +137,6 @@ def get_acl(account_name, headers): return HTTPOk(body=body, content_type="text/plain") -def canonical_string(req): - """ - Canonicalize a request to a token that can be signed. - """ - amz_headers = {} - - buf = "%s\n%s\n%s\n" % (req.method, req.headers.get('Content-MD5', ''), - req.headers.get('Content-Type') or '') - - for amz_header in sorted((key.lower() for key in req.headers - if key.lower().startswith('x-amz-'))): - amz_headers[amz_header] = req.headers[amz_header] - - if 'x-amz-date' in amz_headers: - buf += "\n" - elif 'Date' in req.headers: - buf += "%s\n" % req.headers['Date'] - - for k in sorted(key.lower() for key in amz_headers): - buf += "%s:%s\n" % (k, amz_headers[k]) - - # RAW_PATH_INFO is enabled in later version than eventlet 0.9.17. - # When using older version, swift3 uses req.path of swob instead - # of it. - path = req.environ.get('RAW_PATH_INFO', req.path) - if req.query_string: - path += '?' + req.query_string - if '?' in path: - path, args = path.split('?', 1) - params = [] - for key, value in sorted(req.params.items()): - if key in ALLOWED_SUB_RESOURCES: - params.append('%s=%s' % (key, value) if value else key) - if params: - return '%s%s?%s' % (buf, path, '&'.join(params)) - - return buf + path - - def swift_acl_translate(acl, group='', user='', xml=False): """ Takes an S3 style ACL and returns a list of header/value pairs that @@ -255,21 +213,9 @@ class Controller(object): """ Base WSGI controller class for the middleware """ - def __init__(self, req, app, account_name, token, conf, - container_name=None, object_name=None, **kwargs): + def __init__(self, app, conf, **kwargs): self.app = app self.conf = conf - self.account_name = account_name - self.container_name = container_name - self.object_name = object_name - req.environ['HTTP_X_AUTH_TOKEN'] = token - if object_name: - req.path_info = '/v1/%s/%s/%s' % (account_name, container_name, - object_name) - elif container_name: - req.path_info = '/v1/%s/%s' % (account_name, container_name) - else: - req.path_info = '/v1/%s' % (account_name) class ServiceController(Controller): @@ -318,12 +264,10 @@ class BucketController(Controller): req.query_string = '' resp = req.get_response(self.app) - status = resp.status_int - headers = resp.headers - if status == HTTP_NO_CONTENT: - status = HTTP_OK + if resp.status_int == HTTP_NO_CONTENT: + resp.status_int = HTTP_OK - return Response(status=status, headers=headers, app_iter=resp.app_iter) + return resp def GET(self, req): """ @@ -351,7 +295,7 @@ class BucketController(Controller): if status in (HTTP_UNAUTHORIZED, HTTP_FORBIDDEN): raise AccessDenied() elif status == HTTP_NOT_FOUND: - raise NoSuchBucket(self.container_name) + raise NoSuchBucket(req.container_name) else: raise InternalError() @@ -367,7 +311,7 @@ class BucketController(Controller): is_truncated = 'false' SubElement(elem, 'IsTruncated').text = is_truncated SubElement(elem, 'MaxKeys').text = str(max_keys) - SubElement(elem, 'Name').text = self.container_name + SubElement(elem, 'Name').text = req.container_name for o in objects[:max_keys]: if 'subdir' not in o: @@ -377,7 +321,7 @@ class BucketController(Controller): o['last_modified'] + 'Z' SubElement(contents, 'ETag').text = o['hash'] SubElement(contents, 'Size').text = str(o['bytes']) - add_canonical_user(contents, 'Owner', self.account_name) + add_canonical_user(contents, 'Owner', req.access_key) for o in objects[:max_keys]: if 'subdir' in o: @@ -410,14 +354,6 @@ class BucketController(Controller): for header, acl in translated_acl: req.headers[header] = acl - if 'CONTENT_LENGTH' in req.environ: - try: - if req.content_length < 0: - raise InvalidArgument('Content-Length', req.content_length) - except (ValueError, TypeError): - raise InvalidArgument('Content-Length', - req.environ['CONTENT_LENGTH']) - resp = req.get_response(self.app) status = resp.status_int @@ -425,11 +361,11 @@ class BucketController(Controller): if status in (HTTP_UNAUTHORIZED, HTTP_FORBIDDEN): raise AccessDenied() elif status == HTTP_ACCEPTED: - raise BucketAlreadyExists(self.container_name) + raise BucketAlreadyExists(req.container_name) else: raise InternalError() - return HTTPOk(headers={'Location': self.container_name}) + return HTTPOk(headers={'Location': req.container_name}) def DELETE(self, req): """ @@ -442,7 +378,7 @@ class BucketController(Controller): if status in (HTTP_UNAUTHORIZED, HTTP_FORBIDDEN): raise AccessDenied() elif status == HTTP_NOT_FOUND: - raise NoSuchBucket(self.container_name) + raise NoSuchBucket(req.container_name) elif status == HTTP_CONFLICT: raise BucketNotEmpty() else: @@ -464,18 +400,16 @@ class ObjectController(Controller): def GETorHEAD(self, req): resp = req.get_response(self.app) status = resp.status_int - headers = resp.headers if req.method == 'HEAD': resp.app_iter = None if is_success(status): - return Response(status=status, headers=headers, - app_iter=resp.app_iter) + return resp elif status in (HTTP_UNAUTHORIZED, HTTP_FORBIDDEN): raise AccessDenied() elif status == HTTP_NOT_FOUND: - raise NoSuchKey(self.object_name) + raise NoSuchKey(req.object_name) else: raise InternalError() @@ -495,23 +429,6 @@ class ObjectController(Controller): """ Handle PUT Object and PUT Object (Copy) request """ - for key, value in req.environ.items(): - if key.startswith('HTTP_X_AMZ_META_'): - del req.environ[key] - req.environ['HTTP_X_OBJECT_META_' + key[16:]] = value - elif key == 'HTTP_CONTENT_MD5': - if value == '': - raise InvalidDigest() - try: - req.environ['HTTP_ETAG'] = \ - value.decode('base64').encode('hex') - except Exception: - raise InvalidDigest() - if req.environ['HTTP_ETAG'] == '': - raise SignatureDoesNotMatch() - elif key == 'HTTP_X_AMZ_COPY_SOURCE': - req.environ['HTTP_X_COPY_FROM'] = value - resp = req.get_response(self.app) status = resp.status_int @@ -519,7 +436,7 @@ class ObjectController(Controller): if status in (HTTP_UNAUTHORIZED, HTTP_FORBIDDEN): raise AccessDenied() elif status == HTTP_NOT_FOUND: - raise NoSuchBucket(self.container_name) + raise NoSuchBucket(req.container_name) elif status == HTTP_UNPROCESSABLE_ENTITY: raise InvalidDigest() elif status == HTTP_REQUEST_ENTITY_TOO_LARGE: @@ -553,7 +470,7 @@ class ObjectController(Controller): if status in (HTTP_UNAUTHORIZED, HTTP_FORBIDDEN): raise AccessDenied() elif status == HTTP_NOT_FOUND: - raise NoSuchKey(self.object_name) + raise NoSuchKey(req.object_name) else: raise InternalError() @@ -575,7 +492,7 @@ class AclController(Controller): """ Handles GET Bucket acl and GET Object acl. """ - if self.object_name: + if req.object_name: # Handle Object ACL # ACL requests need to make a HEAD call rather than GET @@ -590,11 +507,11 @@ class AclController(Controller): if is_success(status): # Method must be GET or the body wont be returned to the caller req.environ['REQUEST_METHOD'] = 'GET' - return get_acl(self.account_name, headers) + return get_acl(req.access_key, headers) elif status in (HTTP_UNAUTHORIZED, HTTP_FORBIDDEN): raise AccessDenied() elif status == HTTP_NOT_FOUND: - raise NoSuchKey(self.object_name) + raise NoSuchKey(req.object_name) else: raise InternalError() @@ -605,12 +522,12 @@ class AclController(Controller): headers = resp.headers if is_success(status): - return get_acl(self.account_name, headers) + return get_acl(req.access_key, headers) if status in (HTTP_UNAUTHORIZED, HTTP_FORBIDDEN): raise AccessDenied() elif status == HTTP_NOT_FOUND: - raise NoSuchBucket(self.container_name) + raise NoSuchBucket(req.container_name) else: raise InternalError() @@ -618,7 +535,7 @@ class AclController(Controller): """ Handles PUT Bucket acl and PUT Object acl. """ - if self.object_name: + if req.object_name: # Handle Object ACL raise S3NotImplemented() else: @@ -643,7 +560,7 @@ class AclController(Controller): else: raise InternalError() - return HTTPOk(headers={'Location': self.container_name}) + return HTTPOk(headers={'Location': req.container_name}) class LocationController(Controller): @@ -662,7 +579,7 @@ class LocationController(Controller): if status in (HTTP_UNAUTHORIZED, HTTP_FORBIDDEN): raise AccessDenied() elif status == HTTP_NOT_FOUND: - raise NoSuchBucket(self.container_name) + raise NoSuchBucket(req.container_name) else: raise InternalError() @@ -694,7 +611,7 @@ class LoggingStatusController(Controller): if status in (HTTP_UNAUTHORIZED, HTTP_FORBIDDEN): raise AccessDenied() elif status == HTTP_NOT_FOUND: - raise NoSuchBucket(self.container_name) + raise NoSuchBucket(req.container_name) else: raise InternalError() @@ -741,9 +658,9 @@ class MultiObjectDeleteController(Controller): sub_req.query_string = '' sub_req.content_length = 0 sub_req.method = 'DELETE' - controller = ObjectController(sub_req, self.app, self.account_name, - req.environ['HTTP_X_AUTH_TOKEN'], - self.container_name, key) + sub_req.object_name = key + + controller = ObjectController(self.app, self.conf) try: controller.DELETE(sub_req) except NoSuchKey: @@ -777,7 +694,7 @@ class PartController(Controller): Handles Upload Part and Upload Part Copy. """ # Pass it through, the s3multi upload helper will handle it. - return self.app + return req.get_response(self.app) class UploadsController(Controller): @@ -794,14 +711,14 @@ class UploadsController(Controller): Handles List Multipart Uploads """ # Pass it through, the s3multi upload helper will handle it. - return self.app + return req.get_response(self.app) def POST(self, req): """ Handles Initiate Multipart Upload. """ # Pass it through, the s3multi upload helper will handle it. - return self.app + return req.get_response(self.app) class UploadController(Controller): @@ -819,21 +736,21 @@ class UploadController(Controller): Handles List Parts. """ # Pass it through, the s3multi upload helper will handle it. - return self.app + return req.get_response(self.app) def DELETE(self, req): """ Handles Abort Multipart Upload. """ # Pass it through, the s3multi upload helper will handle it. - return self.app + return req.get_response(self.app) def POST(self, req): """ Handles Complete Multipart Upload. """ # Pass it through, the s3multi upload helper will handle it. - return self.app + return req.get_response(self.app) class VersioningController(Controller): @@ -856,7 +773,7 @@ class VersioningController(Controller): if status in (HTTP_UNAUTHORIZED, HTTP_FORBIDDEN): raise AccessDenied() elif status == HTTP_NOT_FOUND: - raise NoSuchBucket(self.container_name) + raise NoSuchBucket(req.container_name) else: raise InternalError() @@ -880,41 +797,12 @@ class Swift3Middleware(object): self.conf = conf self.logger = get_logger(self.conf, log_route='swift3') - def get_controller(self, req): - container, obj = req.split_path(0, 2, True) - d = dict(container_name=container, object_name=obj) - - if 'acl' in req.params: - return AclController, d - if 'delete' in req.params: - return MultiObjectDeleteController, d - if 'location' in req.params: - return LocationController, d - if 'logging' in req.params: - return LoggingStatusController, d - if 'partNumber' in req.params: - return PartController, d - if 'uploadId' in req.params: - return UploadController, d - if 'uploads' in req.params: - return UploadsController, d - if 'versioning' in req.params: - return VersioningController, d - - if container and obj: - if req.method == 'POST': - if 'uploads' in req.params or 'uploadId' in req.params: - return BucketController, d - return ObjectController, d - elif container: - return BucketController, d - - return ServiceController, d - def __call__(self, env, start_response): - req = Request(env) try: + req = Request(env) resp = self.handle_request(req) + except NotS3Request: + resp = self.app except ErrorResponse as err_resp: if isinstance(err_resp, InternalError): self.logger.exception(err_resp) @@ -928,71 +816,7 @@ class Swift3Middleware(object): self.logger.debug('Calling Swift3 Middleware') self.logger.debug(req.__dict__) - if 'AWSAccessKeyId' in req.params: - try: - req.headers['Date'] = req.params['Expires'] - req.headers['Authorization'] = \ - 'AWS %(AWSAccessKeyId)s:%(Signature)s' % req.params - except KeyError: - raise AccessDenied() - - if 'Authorization' not in req.headers: - return self.app - - try: - keyword, info = req.headers['Authorization'].split(' ') - except Exception: - raise AccessDenied() - - if keyword != 'AWS': - raise AccessDenied() - - try: - account, signature = info.rsplit(':', 1) - except Exception: - err_msg = 'AWS authorization header is invalid. ' \ - 'Expected AwsAccessKeyId:signature' - raise InvalidArgument('Authorization', - req.headers['Authorization'], err_msg) - - try: - controller, path_parts = self.get_controller(req) - except ValueError: - raise InvalidURI(req.path) - - if 'Date' in req.headers: - now = datetime.datetime.utcnow() - date = email.utils.parsedate(req.headers['Date']) - if 'Expires' in req.params: - try: - d = email.utils.formatdate(float(req.params['Expires'])) - except ValueError: - raise AccessDenied() - - # check expiration - expdate = email.utils.parsedate(d) - ex = datetime.datetime(*expdate[0:6]) - if now > ex: - raise AccessDenied() - elif date is not None: - epoch = datetime.datetime(1970, 1, 1, 0, 0, 0, 0) - - d1 = datetime.datetime(*date[0:6]) - if d1 < epoch: - raise AccessDenied() - - # If the standard date is too far ahead or behind, it is an - # error - delta = datetime.timedelta(seconds=60 * 5) - if abs(d1 - now) > delta: - raise RequestTimeTooSkewed() - else: - raise AccessDenied() - - token = base64.urlsafe_b64encode(canonical_string(req)) - - controller = controller(req, self.app, account, token, self.conf, - **path_parts) + controller = req.controller(self.app, self.conf) if hasattr(controller, req.method): res = getattr(controller, req.method)(req) diff --git a/swift3/request.py b/swift3/request.py new file mode 100644 index 00000000..1c504c6e --- /dev/null +++ b/swift3/request.py @@ -0,0 +1,234 @@ +# Copyright (c) 2014 OpenStack Foundation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from urllib import quote +import base64 +import email.utils +import datetime + +from swift.common import swob + +from swift3.response import AccessDenied, InvalidArgument, InvalidDigest, \ + RequestTimeTooSkewed, Response, SignatureDoesNotMatch +from swift3.exception import NotS3Request + +# List of sub-resources that must be maintained as part of the HMAC +# signature string. +ALLOWED_SUB_RESOURCES = sorted([ + 'acl', 'delete', 'lifecycle', 'location', 'logging', 'notification', + 'partNumber', 'policy', 'requestPayment', 'torrent', 'uploads', 'uploadId', + 'versionId', 'versioning', 'versions ', 'website' +]) + + +class Request(swob.Request): + """ + S3 request object. + """ + def __init__(self, env): + swob.Request.__init__(self, env) + + self.access_key, self.signature = self._parse_authorization() + self.container_name, self.object_name = self.split_path(0, 2, True) + self._validate_headers() + self.token = base64.urlsafe_b64encode(self._canonical_string()) + + def _parse_authorization(self): + if 'AWSAccessKeyId' in self.params: + try: + self.headers['Date'] = self.params['Expires'] + self.headers['Authorization'] = \ + 'AWS %(AWSAccessKeyId)s:%(Signature)s' % self.params + except KeyError: + raise AccessDenied() + + if 'Authorization' not in self.headers: + raise NotS3Request + + try: + keyword, info = self.headers['Authorization'].split(' ') + except Exception: + raise AccessDenied() + + if keyword != 'AWS': + raise AccessDenied() + + try: + access_key, signature = info.rsplit(':', 1) + except Exception: + err_msg = 'AWS authorization header is invalid. ' \ + 'Expected AwsAccessKeyId:signature' + raise InvalidArgument('Authorization', + self.headers['Authorization'], err_msg) + + return access_key, signature + + def _validate_headers(self): + if 'CONTENT_LENGTH' in self.environ: + try: + if self.content_length < 0: + raise InvalidArgument('Content-Length', + self.content_length) + except (ValueError, TypeError): + raise InvalidArgument('Content-Length', + self.environ['CONTENT_LENGTH']) + + if 'Date' in self.headers: + now = datetime.datetime.utcnow() + date = email.utils.parsedate(self.headers['Date']) + if 'Expires' in self.params: + try: + d = email.utils.formatdate(float(self.params['Expires'])) + except ValueError: + raise AccessDenied() + + # check expiration + expdate = email.utils.parsedate(d) + ex = datetime.datetime(*expdate[0:6]) + if now > ex: + raise AccessDenied('Request has expired') + elif date is not None: + epoch = datetime.datetime(1970, 1, 1, 0, 0, 0, 0) + + d1 = datetime.datetime(*date[0:6]) + if d1 < epoch: + raise AccessDenied() + + # If the standard date is too far ahead or behind, it is an + # error + delta = datetime.timedelta(seconds=60 * 5) + if abs(d1 - now) > delta: + raise RequestTimeTooSkewed() + else: + raise AccessDenied() + + if 'Content-MD5' in self.headers: + value = self.headers['Content-MD5'] + if value == '': + raise InvalidDigest() + try: + self.headers['ETag'] = value.decode('base64').encode('hex') + except Exception: + raise InvalidDigest() + if self.headers['ETag'] == '': + raise SignatureDoesNotMatch() + + def _canonical_string(self): + """ + Canonicalize a request to a token that can be signed. + """ + amz_headers = {} + + buf = "%s\n%s\n%s\n" % (self.method, + self.headers.get('Content-MD5', ''), + self.headers.get('Content-Type') or '') + + for amz_header in sorted((key.lower() for key in self.headers + if key.lower().startswith('x-amz-'))): + amz_headers[amz_header] = self.headers[amz_header] + + if 'x-amz-date' in amz_headers: + buf += "\n" + elif 'Date' in self.headers: + buf += "%s\n" % self.headers['Date'] + + for k in sorted(key.lower() for key in amz_headers): + buf += "%s:%s\n" % (k, amz_headers[k]) + + path = self.environ.get('RAW_PATH_INFO', self.path) + if self.query_string: + path += '?' + self.query_string + if '?' in path: + path, args = path.split('?', 1) + params = [] + for key, value in sorted(self.params.items()): + if key in ALLOWED_SUB_RESOURCES: + params.append('%s=%s' % (key, value) if value else key) + if params: + return '%s%s?%s' % (buf, path, '&'.join(params)) + + return buf + path + + @property + def controller(self): + from swift3.middleware import ServiceController, BucketController, \ + ObjectController, AclController, MultiObjectDeleteController, \ + LocationController, LoggingStatusController, PartController, \ + UploadController, UploadsController, VersioningController + + if 'acl' in self.params: + return AclController + if 'delete' in self.params: + return MultiObjectDeleteController + if 'location' in self.params: + return LocationController + if 'logging' in self.params: + return LoggingStatusController + if 'partNumber' in self.params: + return PartController + if 'uploadId' in self.params: + return UploadController + if 'uploads' in self.params: + return UploadsController + if 'versioning' in self.params: + return VersioningController + + if self.container_name and self.object_name: + return ObjectController + elif self.container_name: + return BucketController + + return ServiceController + + def to_swift_req(self): + """ + Create a Swift request based on this request's environment. + """ + env = self.environ.copy() + + for key in env: + if key.startswith('HTTP_X_AMZ_META_'): + env['HTTP_X_OBJECT_META_' + key[16:]] = env[key] + del env[key] + + if key == 'HTTP_X_AMZ_COPY_SOURCE': + env['HTTP_X_COPY_FROM'] = env[key] + del env[key] + + env['swift.source'] = 'S3' + env['HTTP_X_AUTH_TOKEN'] = self.token + + if self.object_name: + path = '/v1/%s/%s/%s' % (self.access_key, self.container_name, + self.object_name) + elif self.container_name: + path = '/v1/%s/%s' % (self.access_key, self.container_name) + else: + path = '/v1/%s' % (self.access_key) + env['PATH_INFO'] = path + + env['QUERY_STRING'] = self.query_string + + return swob.Request.blank(quote(path), environ=env) + + def get_response(self, app): + """ + Calls the application with this request's environment. Returns a + Response object that wraps up the application's result. + """ + sw_req = self.to_swift_req() + sw_resp = sw_req.get_response(app) + + return Response.from_swift_resp(sw_resp) diff --git a/swift3/response.py b/swift3/response.py index e8163c59..60e86bd1 100644 --- a/swift3/response.py +++ b/swift3/response.py @@ -90,6 +90,25 @@ class Response(swob.Response): self.headers = headers + @classmethod + def from_swift_resp(cls, sw_resp): + """ + Create a new S3 response object based on the given Swift response. + """ + if sw_resp.app_iter: + body = None + app_iter = sw_resp.app_iter + else: + body = sw_resp.body + app_iter = None + + resp = Response(status=sw_resp.status, headers=sw_resp.headers, + request=sw_resp.request, body=body, app_iter=app_iter, + conditional_response=sw_resp.conditional_response) + resp.environ.update(sw_resp.environ) + + return resp + class StatusMap(object): """ diff --git a/swift3/test/unit/test_swift3.py b/swift3/test/unit/test_swift3.py index d5d64e19..c63af8dc 100644 --- a/swift3/test/unit/test_swift3.py +++ b/swift3/test/unit/test_swift3.py @@ -28,6 +28,7 @@ from swift.common.swob import Request from swift3 import middleware as swift3 from swift3.test.unit.helpers import FakeSwift from swift3.etree import fromstring, tostring, Element, SubElement +from swift3.request import Request as S3Request XMLNS_XSI = 'http://www.w3.org/2001/XMLSchema-instance' @@ -187,7 +188,7 @@ class TestSwift3(unittest.TestCase): environ={'REQUEST_METHOD': 'GET'}, headers={'Authorization': 'AWS test:tester:hmac'}) status, headers, body = self.call_swift3(req) - raw_path_info = "/v1/AUTH_test/%s/%s" % (bucket_name, object_name) + raw_path_info = "/%s/%s" % (bucket_name, object_name) path_info = req.environ['PATH_INFO'] self.assertEquals(path_info, unquote(raw_path_info)) self.assertEquals(req.path, quote(path_info)) @@ -569,10 +570,24 @@ class TestSwift3(unittest.TestCase): The hashes here were generated by running the same requests against boto.utils.canonical_string """ + def canonical_string(path, headers): + if '?' in path: + path, query_string = path.split('?', 1) + else: + query_string = '' + + req = S3Request({ + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': path, + 'QUERY_STRING': query_string, + 'HTTP_AUTHORIZATION': 'AWS X:Y:Z', + }) + req.headers.update(headers) + return req._canonical_string() + def verify(hash, path, headers): - req = Request.blank(path, headers=headers) - self.assertEquals(hash, hashlib.md5( - swift3.canonical_string(req)).hexdigest()) + s = canonical_string(path, headers) + self.assertEquals(hash, hashlib.md5(s).hexdigest()) verify('6dd08c75e42190a1ce9468d1fd2eb787', '/bucket/object', {'Content-Type': 'text/plain', 'X-Amz-Something': 'test', @@ -586,7 +601,7 @@ class TestSwift3(unittest.TestCase): verify('be01bd15d8d47f9fe5e2d9248cc6f180', '/bucket/object', {}) - verify('8d28cc4b8322211f6cc003256cd9439e', 'bucket/object', + verify('e9ec7dca45eef3e2c7276af23135e896', '/bucket/object', {'Content-MD5': 'somestuff'}) verify('a822deb31213ad09af37b5a7fe59e55e', '/bucket/object?acl', {}) @@ -608,16 +623,16 @@ class TestSwift3(unittest.TestCase): {'Content-Type': None, 'Date': 'Tue, 12 Jul 2011 10:52:57 +0000'}) - req1 = Request.blank('/', headers= - {'Content-Type': None, 'X-Amz-Something': 'test'}) - req2 = Request.blank('/', headers= - {'Content-Type': '', 'X-Amz-Something': 'test'}) - req3 = Request.blank('/', headers={'X-Amz-Something': 'test'}) + str1 = canonical_string('/', headers= + {'Content-Type': None, + 'X-Amz-Something': 'test'}) + str2 = canonical_string('/', headers= + {'Content-Type': '', + 'X-Amz-Something': 'test'}) + str3 = canonical_string('/', headers={'X-Amz-Something': 'test'}) - self.assertEquals(swift3.canonical_string(req1), - swift3.canonical_string(req2)) - self.assertEquals(swift3.canonical_string(req2), - swift3.canonical_string(req3)) + self.assertEquals(str1, str2) + self.assertEquals(str2, str3) def test_signed_urls_expired(self): expire = '1000000000' @@ -668,8 +683,9 @@ class TestSwift3(unittest.TestCase): environ={'REQUEST_METHOD': 'PUT'}) req.headers['Authorization'] = 'AWS test:tester:hmac' status, headers, body = self.call_swift3(req) + _, _, headers = self.swift.calls_with_headers[-1] self.assertEquals(base64.urlsafe_b64decode( - req.headers['X-Auth-Token']), + headers['X-Auth-Token']), 'PUT\n\n\n/bucket/object?partNumber=1&uploadId=123456789abcdef') def test_xml_namespace(self):