How to write unit tests? Best practices discussion

When we’re building new models we often write a lot of tests iteratively, but I rarely see this process formalized.

My current tests are built by inheriting from tests.common, which I make importable by separately installing a Python package into my environment that contains only whatever is in common. (I made a feature request, issue #5045, that hasn’t received any love yet.)

from setuptools import setup

setup(
    name="torchtestcommon",        # what you want to call the archive/egg
    version="0.4.0a0",
    packages=["torchtest"],        # top-level python modules you can import
    dependency_links=[],
    package_data={},
    author="Pytorch contributors",
    author_email="",
    description="Copy-paste of pytorch/test/ into ./torchtest/",
)

After running cd torchtestcommon; pip install -e . I can write stuff like:

from torchtest.common import TestCase
testing = TestCase()
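As a sketch of why instantiating a TestCase like this is handy: the same trick works with the stock unittest.TestCase (assuming you only need the generic assert helpers, not torch's tensor-aware overrides), since a bare instance exposes the assert* methods outside any test runner:

```python
import unittest

import torch

# A bare TestCase instance gives access to the assert* helpers
# for quick interactive checks, without defining a test class.
testing = unittest.TestCase()
testing.assertEqual(torch.ones(2, 3).size(), torch.Size([2, 3]))
testing.assertTrue(torch.ones(1).item() == 1.0)
print("ok")  # reached only if both assertions pass
```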


from thisfancyneuralnetworkname import MyModel
from torchtest.common import TestCase
import torch

input_ex = torch.ones([100, 100])
output_ex = torch.ones([100, 100])

class TestMyModel(TestCase):
    def test_forward(self):
        model = MyModel()
        y = model(input_ex)
        self.assertEqual(y.size(), output_ex.size())
        self.assertAlmostEqual(y, output_ex, places=3)

And then I run pytest . to execute them.

How do you do it? Or any link to best practices/example repositories?


Just an update on this: there’s now torch.testing, which you can use similarly to numpy.testing:

In [8]: x = torch.ones(1)
In [9]: y = x+1
In [10]: torch.testing.assert_allclose(x,y)
AssertionError                            Traceback (most recent call last)
<ipython-input-10-3b698e9eddbf> in <module>()
----> 1 torch.testing.assert_allclose(x,y)

/usr/local/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/testing/ in assert_allclose(actual, expected, rtol, atol, equal_nan)
     55     raise AssertionError(msg.format(
     56         rtol, atol, list(index), actual[index].item(), expected[index].item(),
---> 57         count - 1, 100 * count / actual.numel()))

AssertionError: Not within tolerance rtol=0.0001 atol=1e-05 at input[0] (1.0 vs. 2.0) and 0 other locations (100.00%)
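For the passing case, and in case you’re on a recent PyTorch where assert_allclose has since been deprecated in favor of torch.testing.assert_close (check your version), a minimal sketch:

```python
import torch

x = torch.ones(1)
y = x + 1e-6  # difference well inside the default float32 rtol/atol

# assert_close returns None on success and raises AssertionError
# when the tensors differ beyond the tolerances.
torch.testing.assert_close(x, y)
print("within tolerance")
```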