# Copyright (c) 2013, Kevin Greenan (kmgreen2@gmail.com) # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # # Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. THIS SOFTWARE IS # PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS # OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES # OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN # NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF # THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import random from string import ascii_letters import sys import tempfile import time import unittest import pyeclib_c from pyeclib.ec_iface import PyECLib_EC_Types from pyeclib.ec_iface import VALID_EC_TYPES class Timer: def __init__(self): self.start_time = 0 self.end_time = 0 def start(self): self.start_time = time.time() def stop(self): self.end_time = time.time() def curr_delta(self): return self.end_time - self.start_time def stop_and_return(self): self.end_time = time.time() return self.curr_delta() class TestPyECLib(unittest.TestCase): def __init__(self, *args): self.num_datas = [12, 12, 12] self.num_parities = [2, 3, 4] self.iterations = 100 # EC algorithm and config parameters self.rs_types = [(PyECLib_EC_Types.jerasure_rs_vand), (PyECLib_EC_Types.jerasure_rs_cauchy), (PyECLib_EC_Types.isa_l_rs_vand), (PyECLib_EC_Types.liberasurecode_rs_vand)] self.xor_types = [(PyECLib_EC_Types.flat_xor_hd, 12, 6, 4), (PyECLib_EC_Types.flat_xor_hd, 10, 5, 4), (PyECLib_EC_Types.flat_xor_hd, 10, 5, 3)] self.shss = [(PyECLib_EC_Types.shss, 6, 3), (PyECLib_EC_Types.shss, 10, 4), (PyECLib_EC_Types.shss, 20, 4), (PyECLib_EC_Types.shss, 11, 7)] # Input temp files for testing self.sizes = ["101-K", "202-K", "303-K"] self.files = {} self._create_tmp_files() unittest.TestCase.__init__(self, *args) def _create_tmp_files(self): """ Create the temporary files needed for testing. Use the tempfile package so that the files will be automatically removed during garbage collection. """ for size_str in self.sizes: # Determine the size of the file to create size_desc = size_str.split("-") size = int(size_desc[0]) if size_desc[1] == 'M': size *= 1000000 elif size_desc[1] == 'K': size *= 1000 # Create the dictionary of files to test with buf = ''.join(random.choice(ascii_letters) for i in range(size)) if sys.version_info >= (3,): buf = buf.encode('ascii') tmp_file = tempfile.NamedTemporaryFile('w+b') tmp_file.write(buf) self.files[size_str] = tmp_file def get_tmp_file(self, name): """ Acquire a temp file from the dictionary of pre-built, random files with the seek position to the head of the file. """ tmp_file = self.files.get(name, None) if tmp_file: tmp_file.seek(0, 0) return tmp_file def setUp(self): # Ensure that the file offset is set to the head of the file for _, tmp_file in self.files.items(): tmp_file.seek(0, 0) def tearDown(self): pass def time_encode(self, num_data, num_parity, ec_type, hd, file_size, iterations): """ :return average encode time """ timer = Timer() tsum = 0 handle = pyeclib_c.init(num_data, num_parity, ec_type, hd) whole_file_bytes = self.get_tmp_file(file_size).read() timer.start() for l in range(iterations): fragments = pyeclib_c.encode(handle, whole_file_bytes) tsum = timer.stop_and_return() return tsum / iterations def time_decode(self, num_data, num_parity, ec_type, hd, file_size, iterations): """ :return 2-tuple, (success, average decode time) """ timer = Timer() tsum = 0 handle = pyeclib_c.init(num_data, num_parity, ec_type, hd) whole_file_bytes = self.get_tmp_file(file_size).read() success = True fragments = pyeclib_c.encode(handle, whole_file_bytes) orig_fragments = fragments[:] for i in range(iterations): missing_idxs = [] num_missing = hd - 1 for j in range(num_missing): num_frags_left = len(fragments) idx = random.randint(0, num_frags_left - 1) fragments.pop(idx) timer.start() decoded_file_bytes = pyeclib_c.decode(handle, fragments, len(fragments[0])) tsum += timer.stop_and_return() fragments = orig_fragments[:] if whole_file_bytes != decoded_file_bytes: success = False return success, tsum / iterations def time_range_decode(self, num_data, num_parity, ec_type, hd, file_size, iterations): """ :return 2-tuple, (success, average decode time) """ timer = Timer() tsum = 0 handle = pyeclib_c.init(num_data, num_parity, ec_type, hd) whole_file_bytes = self.get_tmp_file(file_size).read() success = True begins = [int(random.randint(0, len(whole_file_bytes) - 1)) for i in range(3)] ends = [int(random.randint(begins[i], len(whole_file_bytes))) for i in range(3)] ranges = list(zip(begins, ends)) fragments = pyeclib_c.encode(handle, whole_file_bytes) orig_fragments = fragments[:] for i in range(iterations): missing_idxs = [] num_missing = hd - 1 for j in range(num_missing): num_frags_left = len(fragments) idx = random.randint(0, num_frags_left - 1) fragments.pop(idx) timer.start() decoded_file_ranges = pyeclib_c.decode(handle, fragments, len(fragments[0]), ranges) tsum += timer.stop_and_return() fragments = orig_fragments[:] range_offset = 0 for r in ranges: if whole_file_bytes[r[0]:r[1]+1] != decoded_file_ranges[range_offset]: success = False range_offset += 1 return success, tsum / iterations def time_reconstruct(self, num_data, num_parity, ec_type, hd, file_size, iterations): """ :return 2-tuple, (success, average reconstruct time) """ timer = Timer() tsum = 0 handle = pyeclib_c.init(num_data, num_parity, ec_type, hd) whole_file_bytes = self.get_tmp_file(file_size).read() success = True orig_fragments = pyeclib_c.encode(handle, whole_file_bytes) for i in range(iterations): fragments = orig_fragments[:] num_missing = 1 missing_idxs = [] for j in range(num_missing): num_frags_left = len(fragments) idx = random.randint(0, num_frags_left - 1) while idx in missing_idxs: idx = random.randint(0, num_frags_left - 1) missing_idxs.append(idx) fragments.pop(idx) timer.start() reconstructed_fragment = pyeclib_c.reconstruct(handle, fragments, len(fragments[0]), missing_idxs[0]) tsum += timer.stop_and_return() if orig_fragments[missing_idxs[0]] != reconstructed_fragment: success = False # Output the fragments for debugging with open("orig_fragments", "wb") as fd_orig: fd_orig.write(orig_fragments[missing_idxs[0]]) with open("decoded_fragments", "wb") as fd_decoded: fd_decoded.write(reconstructed_fragment) print(("Fragment %d was not reconstructed!!!" % missing_idxs[0])) sys.exit(2) return success, tsum / iterations def get_throughput(self, avg_time, size_str): size_desc = size_str.split("-") size = float(size_desc[0]) if size_desc[1] == 'M': throughput = size / avg_time elif size_desc[1] == 'K': throughput = (size / 1000.0) / avg_time return format(throughput, '.10g') def test_xor_code(self): if "flat_xor_hd_3" not in VALID_EC_TYPES: print("xor backend is not available in your enviromnet, skipping test") return for (ec_type, k, m, hd) in self.xor_types: print(("\nRunning tests for flat_xor_hd k=%d, m=%d, hd=%d" % (k, m, hd))) for size_str in self.sizes: avg_time = self.time_encode(k, m, ec_type.value, hd, size_str, self.iterations) print("Encode (%s): %s" % (size_str, self.get_throughput(avg_time, size_str))) for size_str in self.sizes: success, avg_time = self.time_decode(k, m, ec_type.value, hd, size_str, self.iterations) self.assertTrue(success) print("Decode (%s): %s" % (size_str, self.get_throughput(avg_time, size_str))) for size_str in self.sizes: success, avg_time = self.time_reconstruct(k, m, ec_type.value, hd, size_str, self.iterations) self.assertTrue(success) print("Reconstruct (%s): %s" % (size_str, self.get_throughput(avg_time, size_str))) def test_shss(self): if "shss" not in VALID_EC_TYPES: print("shss backend is not available in your enviromnet, skipping test") return for (ec_type, k, m) in self.shss: print(("\nRunning tests for %s k=%d, m=%d" % (ec_type, k, m))) success = self._test_get_required_fragments(k, m, ec_type) self.assertTrue(success) for size_str in self.sizes: avg_time = self.time_encode(k, m, ec_type.value, 0, size_str, self.iterations) print("Encode (%s): %s" % (size_str, self.get_throughput(avg_time, size_str))) for size_str in self.sizes: success, avg_time = self.time_decode(k, m, ec_type.value, 0, size_str, self.iterations) self.assertTrue(success) print("Decode (%s): %s" % (size_str, self.get_throughput(avg_time, size_str))) for size_str in self.sizes: success, avg_time = self.time_reconstruct(k, m, ec_type.value, 0, size_str, self.iterations) self.assertTrue(success) print("Reconstruct (%s): %s" % (size_str, self.get_throughput(avg_time, size_str))) def _test_get_required_fragments(self, num_data, num_parity, ec_type): """ :return boolean, True if all tests passed """ handle = pyeclib_c.init(num_data, num_parity, ec_type.value) success = True # # MDS codes need any k fragments # if ec_type in ["jerasure_rs_vand", "jerasure_rs_cauchy", "liberasurecode_rs_vand"]: expected_fragments = [i for i in range(num_data + num_parity)] missing_fragments = [] # # Remove between 1 and num_parity # for i in range(random.randint(0, num_parity - 1)): missing_fragment = random.sample(expected_fragments, 1)[0] missing_fragments.append(missing_fragment) expected_fragments.remove(missing_fragment) expected_fragments = expected_fragments[:num_data] required_fragments = pyeclib_c.get_required_fragments( handle, missing_fragments, []) if expected_fragments != required_fragments: success = False print(("Unexpected required fragments list " "(exp != req): %s != %s" % (expected_fragments, required_fragments))) return success def test_codes(self): for ec_type in self.rs_types: if ec_type.name not in VALID_EC_TYPES: print("%s backend is not available in your enviromnet, skipping test" % ec_type.name) continue print(("\nRunning tests for %s" % (ec_type))) for i in range(len(self.num_datas)): success = self._test_get_required_fragments(self.num_datas[i], self.num_parities[i], ec_type) self.assertTrue(success) for i in range(len(self.num_datas)): for size_str in self.sizes: avg_time = self.time_encode(self.num_datas[i], self.num_parities[i], ec_type.value, self.num_parities[i] + 1, size_str, self.iterations) print(("Encode (%s): %s" % (size_str, self.get_throughput(avg_time, size_str)))) for i in range(len(self.num_datas)): for size_str in self.sizes: success, avg_time = self.time_decode(self.num_datas[i], self.num_parities[i], ec_type.value, self.num_parities[i] + 1, size_str, self.iterations) self.assertTrue(success) print(("Decode (%s): %s" % (size_str, self.get_throughput(avg_time, size_str)))) for i in range(len(self.num_datas)): for size_str in self.sizes: success, avg_time = self.time_range_decode(self.num_datas[i], self.num_parities[i], ec_type.value, self.num_parities[i] + 1, size_str, self.iterations) self.assertTrue(success) print(("Range Decode (%s): %s" % (size_str, self.get_throughput(avg_time, size_str)))) for i in range(len(self.num_datas)): for size_str in self.sizes: success, avg_time = self.time_reconstruct(self.num_datas[i], self.num_parities[i], ec_type.value, self.num_parities[i] + 1, size_str, self.iterations) self.assertTrue(success) print(("Reconstruct (%s): %s" % (size_str, self.get_throughput(avg_time, size_str)))) if __name__ == "__main__": unittest.main()