Merge pull request #337 from dpkp/gzip_context

Use context managers in gzip_encode / gzip_decode
This commit is contained in:
Dana Powers
2015-03-29 15:43:31 -07:00

View File

@@ -1,8 +1,7 @@
from io import BytesIO
import gzip import gzip
from io import BytesIO
import struct import struct
import six
from six.moves import xrange from six.moves import xrange
_XERIAL_V1_HEADER = (-126, b'S', b'N', b'A', b'P', b'P', b'Y', 0, 1, 1) _XERIAL_V1_HEADER = (-126, b'S', b'N', b'A', b'P', b'P', b'Y', 0, 1, 1)
@@ -10,9 +9,9 @@ _XERIAL_V1_FORMAT = 'bccccccBii'
try: try:
import snappy import snappy
_has_snappy = True _HAS_SNAPPY = True
except ImportError: except ImportError:
_has_snappy = False _HAS_SNAPPY = False
def has_gzip(): def has_gzip():
@@ -20,26 +19,36 @@ def has_gzip():
def has_snappy(): def has_snappy():
return _has_snappy return _HAS_SNAPPY
def gzip_encode(payload): def gzip_encode(payload):
buffer = BytesIO() with BytesIO() as buf:
handle = gzip.GzipFile(fileobj=buffer, mode="w")
handle.write(payload) # Gzip context manager introduced in python 2.6
handle.close() # so old-fashioned way until we decide to not support 2.6
buffer.seek(0) gzipper = gzip.GzipFile(fileobj=buf, mode="w")
result = buffer.read() try:
buffer.close() gzipper.write(payload)
finally:
gzipper.close()
result = buf.getvalue()
return result return result
def gzip_decode(payload): def gzip_decode(payload):
buffer = BytesIO(payload) with BytesIO(payload) as buf:
handle = gzip.GzipFile(fileobj=buffer, mode='r')
result = handle.read() # Gzip context manager introduced in python 2.6
handle.close() # so old-fashioned way until we decide to not support 2.6
buffer.close() gzipper = gzip.GzipFile(fileobj=buf, mode='r')
try:
result = gzipper.read()
finally:
gzipper.close()
return result return result
@@ -47,8 +56,8 @@ def snappy_encode(payload, xerial_compatible=False, xerial_blocksize=32 * 1024):
"""Encodes the given data with snappy if xerial_compatible is set then the """Encodes the given data with snappy if xerial_compatible is set then the
stream is encoded in a fashion compatible with the xerial snappy library stream is encoded in a fashion compatible with the xerial snappy library
The block size (xerial_blocksize) controls how frequent the blocking occurs The block size (xerial_blocksize) controls how frequent the blocking
32k is the default in the xerial library. occurs 32k is the default in the xerial library.
The format winds up being The format winds up being
+-------------+------------+--------------+------------+--------------+ +-------------+------------+--------------+------------+--------------+
@@ -63,7 +72,7 @@ def snappy_encode(payload, xerial_compatible=False, xerial_blocksize=32 * 1024):
length will always be <= blocksize. length will always be <= blocksize.
""" """
if not _has_snappy: if not has_snappy():
raise NotImplementedError("Snappy codec is not available") raise NotImplementedError("Snappy codec is not available")
if xerial_compatible: if xerial_compatible:
@@ -74,7 +83,7 @@ def snappy_encode(payload, xerial_compatible=False, xerial_blocksize=32 * 1024):
out = BytesIO() out = BytesIO()
header = b''.join([struct.pack('!' + fmt, dat) for fmt, dat header = b''.join([struct.pack('!' + fmt, dat) for fmt, dat
in zip(_XERIAL_V1_FORMAT, _XERIAL_V1_HEADER)]) in zip(_XERIAL_V1_FORMAT, _XERIAL_V1_HEADER)])
out.write(header) out.write(header)
for chunk in _chunker(): for chunk in _chunker():
@@ -113,13 +122,13 @@ def _detect_xerial_stream(payload):
""" """
if len(payload) > 16: if len(payload) > 16:
header = header = struct.unpack('!' + _XERIAL_V1_FORMAT, bytes(payload)[:16]) header = struct.unpack('!' + _XERIAL_V1_FORMAT, bytes(payload)[:16])
return header == _XERIAL_V1_HEADER return header == _XERIAL_V1_HEADER
return False return False
def snappy_decode(payload): def snappy_decode(payload):
if not _has_snappy: if not has_snappy():
raise NotImplementedError("Snappy codec is not available") raise NotImplementedError("Snappy codec is not available")
if _detect_xerial_stream(payload): if _detect_xerial_stream(payload):