# Copyright (c) 2018-2019 Hewlett Packard Enterprise Development LP # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. # See the License for the specific language governing permissions and # limitations under the License. # import logging import os import fixtures import testtools class ChannelFixture(fixtures.Fixture): def __init__(self, channel='stdout', object=None): self._channel = channel if object is None: self._object = 'sys.%s' % channel else: self._object = object def _setUp(self): string_fixture = self.useFixture(fixtures.StringStream(self._channel)) self.stream = string_fixture.stream self.useFixture( fixtures.MonkeyPatch(self._object, self.stream) ) def getvalue(self): self.stream.seek(0) return self.stream.read() class BaseTestCase(testtools.TestCase): def setUp(self): super(BaseTestCase, self).setUp() # capture stdout/stderr for tests to inspect easily self.stdout = self.useFixture(ChannelFixture('stdout')) self.stderr = self.useFixture(ChannelFixture('stderr')) self.logger = self.useFixture( fixtures.FakeLogger( level=logging.DEBUG, format="%(levelname)s:%(name)s:%(message)s" ) ) def get_testfile(self, ext): *path_parts, clsname, testname = self.id().split('.') testname = testname.replace("test_", '', 1) possible_test_fixtures = ( "%s.%s.%s.%s" % (path_parts[-1], clsname, testname, ext), "%s.%s.%s" % (path_parts[-1], clsname, ext), "%s.%s" % (path_parts[-1], ext), ) for fpath in possible_test_fixtures: testdatafile = os.path.join(*(path_parts[:-1]), "fixtures", fpath) if os.path.exists(testdatafile): return testdatafile else: self.test_logger.warn( "No file with test data found from patterns (%s) for test " "id %s" % (", ".join(possible_test_fixtures), self.id()) ) class IsOrderedSubsetOfMismatch(object): def __init__(self, subset, set): self.subset = list(subset) self.set = list(set) def describe(self): return "set %r is not an ordered subset of %r" % ( self.subset, self.set) def get_details(self): return {} class IsOrderedSubsetOf(object): """Matches if the actual matches the order of iterable.""" def __init__(self, iterable): self.iterable = iterable def __str__(self): return 'IsOrderedSubsetOf(%s)' % self.iterable def match(self, actual): iterable = iter(self.iterable) if all(item in iterable for item in actual): return None else: return IsOrderedSubsetOfMismatch(actual, self.iterable)