Add ParseResultBytes

This gives us API compatibility with the standard library (for the most part).
This commit is contained in:
Ian Cordasco 2015-07-05 10:34:48 -05:00
parent e981295fee
commit 325d71cc94
3 changed files with 237 additions and 46 deletions

View File

@ -14,15 +14,57 @@
# limitations under the License.
from collections import namedtuple
from . import compat
from . import exceptions
from . import normalizers
from . import uri
__all__ = ('ParseResult', 'ParseResultBytes')
PARSED_COMPONENTS = ('scheme', 'userinfo', 'host', 'port', 'path', 'query',
'fragment')
class ParseResult(namedtuple('ParseResult', PARSED_COMPONENTS)):
class ParseResultMixin(object):
def _generate_authority(self, attributes):
# I swear I did not align the comparisons below. That's just how they
# happened to align based on pep8 and attribute lengths.
userinfo, host, port = (attributes[p]
for p in ('userinfo', 'host', 'port'))
if (self.userinfo != userinfo or
self.host != host or
self.port != port):
if port:
port = '{0}'.format(port)
return normalizers.normalize_authority(
(compat.to_str(userinfo, self.encoding),
compat.to_str(host, self.encoding),
port)
)
return self.authority
def geturl(self):
"""Standard library shim to the unsplit method."""
return self.unsplit()
@property
def hostname(self):
"""Standard library shim for the host portion of the URI."""
return self.host
@property
def netloc(self):
"""Standard library shim for the authority portion of the URI."""
return self.authority
@property
def params(self):
"""Standard library shim for the query portion of the URI."""
return self.query
class ParseResult(namedtuple('ParseResult', PARSED_COMPONENTS),
ParseResultMixin):
slots = ()
def __new__(cls, scheme, userinfo, host, port, path, query, fragment,
@ -85,19 +127,6 @@ class ParseResult(namedtuple('ParseResult', PARSED_COMPONENTS)):
"""Normalized authority generated from the subauthority parts."""
return self.reference.authority
def _generate_authority(self, attributes):
# I swear I did not align the comparisons below. That's just how they
# happened to align based on pep8 and attribute lengths.
userinfo, host, port = (attributes[p]
for p in ('userinfo', 'host', 'port'))
if (self.userinfo != userinfo or
self.host != host or
self.port != port):
if port:
port = '{0}'.format(port)
return normalizers.normalize_authority((userinfo, host, port))
return self.authority
def copy_with(self, scheme=None, userinfo=None, host=None, port=None,
path=None, query=None, fragment=None):
attributes = zip(PARSED_COMPONENTS,
@ -108,31 +137,24 @@ class ParseResult(namedtuple('ParseResult', PARSED_COMPONENTS)):
value = getattr(self, name)
attrs_dict[name] = value
authority = self._generate_authority(attrs_dict)
ref = self.reference.copy_with(scheme=attrs_dict.get('scheme'),
ref = self.reference.copy_with(scheme=attrs_dict['scheme'],
authority=authority,
path=attrs_dict.get('path'),
query=attrs_dict.get('query'),
fragment=attrs_dict.get('fragment'))
return ParseResult(uri_ref=ref, **attrs_dict)
path=attrs_dict['path'],
query=attrs_dict['query'],
fragment=attrs_dict['fragment'])
return ParseResult(uri_ref=ref, encoding=self.encoding, **attrs_dict)
def geturl(self):
"""Standard library shim to the unsplit method."""
return self.unsplit()
@property
def hostname(self):
"""Standard library shim for the host portion of the URI."""
return self.host
@property
def netloc(self):
"""Standard library shim for the authority portion of the URI."""
return self.authority
@property
def params(self):
"""Standard library shim for the query portion of the URI."""
return self.query
def encode(self, encoding=None):
encoding = encoding or self.encoding
attrs = dict(
zip(PARSED_COMPONENTS,
(attr.encode(encoding) if hasattr(attr, 'encode') else attr
for attr in self)))
return ParseResultBytes(
uri_ref=self.reference,
encoding=encoding,
**attrs
)
def unsplit(self, use_idna=False):
"""Create a URI string from the components.
@ -141,13 +163,119 @@ class ParseResult(namedtuple('ParseResult', PARSED_COMPONENTS)):
:rtype: str
"""
parse_result = self
if use_idna:
if use_idna and self.host:
hostbytes = self.host.encode('idna')
host = hostbytes.decode(self.encoding)
parse_result = self.copy_with(host=host)
return parse_result.reference.unsplit()
class ParseResultBytes(namedtuple('ParseResultBytes', PARSED_COMPONENTS),
ParseResultMixin):
def __new__(cls, scheme, userinfo, host, port, path, query, fragment,
uri_ref, encoding='utf-8'):
parse_result = super(ParseResultBytes, cls).__new__(
cls,
scheme or None,
userinfo or None,
host,
port or None,
path or None,
query or None,
fragment or None)
parse_result.encoding = encoding
parse_result.reference = uri_ref
return parse_result
@classmethod
def from_string(cls, uri_string, encoding='utf-8', strict=True):
"""Parse a URI from the given unicode URI string.
:param str uri_string: Unicode URI to be parsed into a reference.
:param str encoding: The encoding of the string provided
:param bool strict: Parse strictly according to :rfc:`3986` if True.
If False, parse similarly to the standard library's urlparse
function.
:returns: :class:`ParseResultBytes` or subclass thereof
"""
reference = uri.URIReference.from_string(uri_string, encoding)
try:
subauthority = reference.authority_info()
except exceptions.InvalidAuthority:
if strict:
raise
userinfo, host, port = split_authority(reference.authority)
else:
# Thanks to Richard Barrell for this idea:
# https://twitter.com/0x2ba22e11/status/617338811975139328
userinfo, host, port = (subauthority.get(p)
for p in ('userinfo', 'host', 'port'))
if port:
try:
port = int(port)
except ValueError:
raise exceptions.InvalidPort(port)
to_bytes = compat.to_bytes
return cls(scheme=to_bytes(reference.scheme, encoding),
userinfo=to_bytes(userinfo, encoding),
host=to_bytes(host, encoding),
port=port,
path=to_bytes(reference.path, encoding),
query=to_bytes(reference.query, encoding),
fragment=to_bytes(reference.fragment, encoding),
uri_ref=reference,
encoding=encoding)
@property
def authority(self):
"""Normalized authority generated from the subauthority parts."""
return self.reference.authority.encode(self.encoding)
def copy_with(self, scheme=None, userinfo=None, host=None, port=None,
path=None, query=None, fragment=None):
attributes = zip(PARSED_COMPONENTS,
(scheme, userinfo, host, port, path, query, fragment))
attrs_dict = {}
for name, value in attributes:
if value is None:
value = getattr(self, name)
if not isinstance(value, bytes) and hasattr(value, 'encode'):
value = value.encode(self.encoding)
attrs_dict[name] = value
authority = self._generate_authority(attrs_dict)
to_str = compat.to_str
ref = self.reference.copy_with(
scheme=to_str(attrs_dict['scheme'], self.encoding),
authority=authority,
path=to_str(attrs_dict['path'], self.encoding),
query=to_str(attrs_dict['query'], self.encoding),
fragment=to_str(attrs_dict['fragment'], self.encoding)
)
return ParseResultBytes(
uri_ref=ref,
encoding=self.encoding,
**attrs_dict
)
def unsplit(self, use_idna=False):
"""Create a URI bytes object from the components.
:returns: The parsed URI reconstituted as a string.
:rtype: bytes
"""
parse_result = self
if use_idna and self.host:
# self.host is bytes, to encode to idna, we need to decode it
# first
host = self.host.decode(self.encoding)
hostbytes = host.encode('idna')
parse_result = self.copy_with(host=hostbytes)
uri = parse_result.reference.unsplit()
return uri.encode(self.encoding)
def split_authority(authority):
# Initialize our expected return values
userinfo = host = port = None

View File

@ -12,10 +12,25 @@
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import rfc3986
from rfc3986 import exceptions
from rfc3986 import parseresult as pr
import pytest
from . import base
INVALID_PORTS = ['443:80', '443:80:443', 'abcdef', 'port', '43port']
SNOWMAN = b'\xe2\x98\x83'
SNOWMAN_IDNA_HOST = 'http://xn--n3h.com'
@pytest.mark.parametrize('port', INVALID_PORTS)
def test_port_parsing(port):
with pytest.raises(exceptions.InvalidPort):
rfc3986.urlparse('https://httpbin.org:{0}/get'.format(port))
class TestParseResultParsesURIs(base.BaseTestParsesURIs):
test_class = pr.ParseResult
@ -44,3 +59,59 @@ def test_creates_a_copy_with_a_new_port(basic_uri):
uri = pr.ParseResult.from_string(basic_uri)
new_uri = uri.copy_with(port=443)
assert new_uri.port == 443
def test_parse_result_encodes_itself(uri_with_everything):
uri = pr.ParseResult.from_string(uri_with_everything)
uribytes = uri.encode()
encoding = uri.encoding
assert uri.scheme.encode(encoding) == uribytes.scheme
assert uri.userinfo.encode(encoding) == uribytes.userinfo
assert uri.host.encode(encoding) == uribytes.host
assert uri.port == uribytes.port
assert uri.path.encode(encoding) == uribytes.path
assert uri.query.encode(encoding) == uribytes.query
assert uri.fragment.encode(encoding) == uribytes.fragment
class TestParseResultBytes:
def test_handles_uri_with_everything(self, uri_with_everything):
uri = pr.ParseResultBytes.from_string(uri_with_everything)
assert uri.scheme == b'https'
assert uri.path == b'/path/to/resource'
assert uri.query == b'key=value'
assert uri.fragment == b'fragment'
assert uri.userinfo == b'user:pass'
assert uri.port == 443
assert isinstance(uri.authority, bytes) is True
def test_raises_invalid_authority_for_invalid_uris(self, invalid_uri):
with pytest.raises(exceptions.InvalidAuthority):
pr.ParseResultBytes.from_string(invalid_uri)
@pytest.mark.parametrize('port', INVALID_PORTS)
def test_raises_invalid_port_non_strict_parse(self, port):
with pytest.raises(exceptions.InvalidPort):
pr.ParseResultBytes.from_string(
'https://httpbin.org:{0}/get'.format(port),
strict=False
)
def test_copy_with_a_new_path(self, uri_with_everything):
uri = pr.ParseResultBytes.from_string(uri_with_everything)
new_uri = uri.copy_with(path=b'/parse/result/tests/are/fun')
assert new_uri.path == b'/parse/result/tests/are/fun'
def test_copy_with_a_new_unicode_path(self, uri_with_everything):
uri = pr.ParseResultBytes.from_string(uri_with_everything)
pathbytes = b'/parse/result/tests/are/fun' + SNOWMAN
new_uri = uri.copy_with(path=pathbytes.decode('utf-8'))
assert new_uri.path == (b'/parse/result/tests/are/fun' + SNOWMAN)
def test_unsplit(self):
uri = pr.ParseResultBytes.from_string(
b'http://' + SNOWMAN + b'.com/path',
strict=False
)
idna_encoded = SNOWMAN_IDNA_HOST.encode('utf-8') + b'/path'
assert uri.unsplit(use_idna=True) == idna_encoded

View File

@ -12,8 +12,6 @@ SNOWMAN_PARAMS = b'http://example.com?utf8=' + SNOWMAN
SNOWMAN_HOST = b'http://' + SNOWMAN + b'.com'
SNOWMAN_IDNA_HOST = 'http://xn--n3h.com'
INVALID_PORTS = ['443:80', '443:80:443', 'abcdef', 'port', '43port']
def test_unicode_uri():
url_bytestring = SNOWMAN_PARAMS
@ -67,9 +65,3 @@ def test_unsplit_idna_a_unicode_hostname():
def test_strict_urlparsing():
with pytest.raises(exceptions.InvalidAuthority):
parseresult.ParseResult.from_string(SNOWMAN_HOST)
@pytest.mark.parametrize('port', INVALID_PORTS)
def test_port_parsing(port):
with pytest.raises(exceptions.InvalidPort):
urlparse('https://httpbin.org:{0}/get'.format(port))