Improve landscape.io ratings by reducing ddt complexity

This commit is contained in:
Carles Barrobés
2015-01-23 12:13:51 +01:00
parent 2b16b287a3
commit 08aacf004f
4 changed files with 80 additions and 68 deletions

110
ddt.py
View File

@@ -1,3 +1,10 @@
# -*- coding: utf-8 -*-
# This file is a part of DDT (https://github.com/txels/ddt)
# Copyright 2012-2015 Carles Barrobés and DDT contributors
# For the exact contribution history, see the git revision log.
# DDT is licensed under the MIT License, included in
# https://github.com/txels/ddt/blob/master/LICENSE.md
import inspect
import json
import os
@@ -129,6 +136,55 @@ def mk_test_name(name, value, index=0):
return re.sub('\W|^(?=\d)', '_', test_name)
def feed_data(func, new_name, *args, **kwargs):
"""
This internal method decorator feeds the test data item to the test.
"""
@wraps(func)
def wrapper(self):
return func(self, *args, **kwargs)
wrapper.__name__ = new_name
return wrapper
def add_test(cls, test_name, func, *args, **kwargs):
"""
Add a test case to this class.
The test will be based on an existing function but will give it a new
name.
"""
setattr(cls, test_name, feed_data(func, test_name, *args, **kwargs))
def process_file_data(cls, name, func, file_attr):
"""
Process the parameter in the `file_data` decorator.
"""
cls_path = os.path.abspath(inspect.getsourcefile(cls))
data_file_path = os.path.join(os.path.dirname(cls_path), file_attr)
def _raise_ve(*args): # pylint: disable-msg=W0613
raise ValueError("%s does not exist" % file_attr)
if os.path.exists(data_file_path) is False:
test_name = mk_test_name(name, "error")
add_test(cls, test_name, _raise_ve, None)
else:
data = json.loads(open(data_file_path).read())
for i, elem in enumerate(data):
if isinstance(data, dict):
key, value = elem, data[elem]
test_name = mk_test_name(name, key, i)
elif isinstance(data, list):
value = elem
test_name = mk_test_name(name, value, i)
add_test(cls, test_name, func, value)
def ddt(cls):
"""
Class decorator for subclasses of ``unittest.TestCase``.
@@ -153,67 +209,21 @@ def ddt(cls):
from the ``data`` key.
"""
def feed_data(func, new_name, *args, **kwargs):
"""
This internal method decorator feeds the test data item to the test.
"""
@wraps(func)
def wrapper(self):
return func(self, *args, **kwargs)
wrapper.__name__ = new_name
return wrapper
def add_test(test_name, func, *args, **kwargs):
"""
Add a test case to this class.
The test will be based on an existing function but will give it a new
name.
"""
setattr(cls, test_name, feed_data(func, test_name, *args, **kwargs))
def process_file_data(name, func, file_attr):
"""
Process the parameter in the `file_data` decorator.
"""
cls_path = os.path.abspath(inspect.getsourcefile(cls))
data_file_path = os.path.join(os.path.dirname(cls_path), file_attr)
def _raise_ve(*args):
raise ValueError("%s does not exist" % file_attr)
if os.path.exists(data_file_path) is False:
test_name = mk_test_name(name, "error")
add_test(test_name, _raise_ve, None)
else:
data = json.loads(open(data_file_path).read())
for i, elem in enumerate(data):
if isinstance(data, dict):
key, value = elem, data[elem]
test_name = mk_test_name(name, key, i)
elif isinstance(data, list):
value = elem
test_name = mk_test_name(name, value, i)
add_test(test_name, func, value)
for name, func in list(cls.__dict__.items()):
if hasattr(func, DATA_ATTR):
for i, v in enumerate(getattr(func, DATA_ATTR)):
test_name = mk_test_name(name, getattr(v, "__name__", v), i)
if hasattr(func, UNPACK_ATTR):
if isinstance(v, tuple) or isinstance(v, list):
add_test(test_name, func, *v)
add_test(cls, test_name, func, *v)
else:
# unpack dictionary
add_test(test_name, func, **v)
add_test(cls, test_name, func, **v)
else:
add_test(test_name, func, v)
add_test(cls, test_name, func, v)
delattr(cls, name)
elif hasattr(func, FILE_ATTR):
file_attr = getattr(func, FILE_ATTR)
process_file_data(name, func, file_attr)
process_file_data(cls, name, func, file_attr)
delattr(cls, name)
return cls

View File

@@ -48,6 +48,8 @@ master_doc = 'index'
# General information about the project.
project = u'DDT'
# pylint: disable-msg=W0622
# - copyright is a builtin
copyright = u'2012, Carles Barrobés'
# The version info for the project you're documenting, acts as replacement for

View File

@@ -3,12 +3,12 @@ from ddt import ddt, data, file_data, unpack
from test.mycode import larger_than_two, has_three_elements, is_a_greeting
class mylist(list):
class Mylist(list):
pass
def annotated(a, b):
r = mylist([a, b])
r = Mylist([a, b])
setattr(r, "__name__", "test_%d_greater_than_%d" % (a, b))
return r

View File

@@ -187,21 +187,21 @@ def test_ddt_data_name_attribute():
def hello():
pass
class myint(int):
class Myint(int):
pass
class mytest(object):
class Mytest(object):
pass
d1 = myint(1)
d1 = Myint(1)
d1.__name__ = 'data1'
d2 = myint(2)
d2 = Myint(2)
data_hello = data(d1, d2)(hello)
setattr(mytest, 'test_hello', data_hello)
setattr(Mytest, 'test_hello', data_hello)
ddt_mytest = ddt(mytest)
ddt_mytest = ddt(Mytest)
assert_is_not_none(getattr(ddt_mytest, 'test_hello_1_data1'))
assert_is_not_none(getattr(ddt_mytest, 'test_hello_2_2'))
@@ -219,33 +219,33 @@ def test_ddt_data_unicode():
if six.PY2:
@ddt
class mytest(object):
class Mytest(object):
@data(u'ascii', u'non-ascii-\N{SNOWMAN}', {u'\N{SNOWMAN}': 'data'})
def test_hello(self, val):
pass
assert_is_not_none(getattr(mytest, 'test_hello_1_ascii'))
assert_is_not_none(getattr(mytest, 'test_hello_2_non_ascii__u2603'))
assert_is_not_none(getattr(Mytest, 'test_hello_1_ascii'))
assert_is_not_none(getattr(Mytest, 'test_hello_2_non_ascii__u2603'))
if is_hash_randomized():
assert_is_not_none(getattr(mytest, 'test_hello_3'))
assert_is_not_none(getattr(Mytest, 'test_hello_3'))
else:
assert_is_not_none(getattr(mytest,
assert_is_not_none(getattr(Mytest,
'test_hello_3__u__u2603____data__'))
elif six.PY3:
@ddt
class mytest(object):
class Mytest(object):
@data('ascii', 'non-ascii-\N{SNOWMAN}', {'\N{SNOWMAN}': 'data'})
def test_hello(self, val):
pass
assert_is_not_none(getattr(mytest, 'test_hello_1_ascii'))
assert_is_not_none(getattr(mytest, 'test_hello_2_non_ascii__'))
assert_is_not_none(getattr(Mytest, 'test_hello_1_ascii'))
assert_is_not_none(getattr(Mytest, 'test_hello_2_non_ascii__'))
if is_hash_randomized():
assert_is_not_none(getattr(mytest, 'test_hello_3'))
assert_is_not_none(getattr(Mytest, 'test_hello_3'))
else:
assert_is_not_none(getattr(mytest, 'test_hello_3________data__'))
assert_is_not_none(getattr(Mytest, 'test_hello_3________data__'))
def test_feed_data_with_invalid_identifier():