How to confirm that PyTorch code is working as intended / unit testing PyTorch code?

The question is exactly what the title says. Suppose I have a piece of code. I want to ensure that the backprop updates are all happening as I expect them to. I am looking for something along the lines of unit testing, or a principled approach to it.

5 Likes

You could create unit tests by storing the values of your parameters and checking that they update after a training iteration. The concept is explained here. It’s for TensorFlow, but you can easily adapt it to PyTorch.
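In PyTorch, that idea boils down to snapshotting the parameters, running one training step, and asserting they changed. A minimal sketch (the model, data, and hyperparameters below are made up for illustration):

```python
import torch
import torch.nn as nn

# hypothetical tiny model and random batch, just for the test
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
data, target = torch.randn(8, 4), torch.randn(8, 2)

# snapshot every parameter before the step (detach + clone -> independent copy)
before = [p.detach().clone() for p in model.parameters()]

# one training iteration
optimizer.zero_grad()
loss = criterion(model(data), target)
loss.backward()
optimizer.step()

# every parameter should have changed after the update
for p0, p1 in zip(before, model.parameters()):
    assert not torch.equal(p0, p1), "a parameter did not update"
```

If a parameter is accidentally frozen (e.g. missing from the optimizer, or its gradient never flows), the corresponding assertion fires.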

4 Likes

That was an amazing post. Thank you for the answer. Time to apply it to my own code. Thanks a lot.

@ptrblck In the blog post you mentioned, I found the code for testing in TensorFlow (https://github.com/Thenerdstation/mltest/blob/master/mltest/mltest.py) and am trying to convert it to PyTorch.

Is there a simple way to run a piece of code defined elsewhere, the way sess.run does in TensorFlow? In particular, I want to replicate what happens at L83 (https://github.com/Thenerdstation/mltest/blob/master/mltest/mltest.py#L83), but I can’t think of an easy way to do it in PyTorch.

1 Like

In PyTorch you don’t need to work with sessions as in TensorFlow. Just .clone() the values you want to snapshot.
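The reason .clone() matters: optimizer.step() mutates parameters in place, so a plain reference would silently track the update instead of preserving the old values. A small sketch of the difference (the layer and the in-place update are arbitrary):

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 3)
snapshot = layer.weight.detach().clone()  # independent copy of the values
alias = layer.weight                      # same tensor, same storage

with torch.no_grad():
    layer.weight.add_(1.0)                # simulate an in-place update

assert torch.equal(alias, layer.weight)          # the alias followed the change
assert not torch.equal(snapshot, layer.weight)   # the clone kept the old values
```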

@ptrblck Thanks for the suggestion. Is there a similar trick that can be used for L84 (https://github.com/Thenerdstation/mltest/blob/master/mltest/mltest.py#L84)? I was planning to move the exact operation into the function, but I’m not sure whether that is correct practice.

It seems op is the “training operation”, e.g. op = tf.train.AdamOptimizer().minimize(var).
Using PyTorch you can just perform a training step:

...
optimizer.zero_grad()   # reset gradients accumulated from the previous step
output = model(data)    # forward pass
loss = criterion(output, target)
loss.backward()         # backprop
optimizer.step()        # parameter update

Does this help?
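To mimic mltest’s op argument more directly, you can wrap that training step in a plain callable and hand it to a generic checker. The helper names below (train_step, assert_vars_change) are illustrative, not an existing API:

```python
import torch
import torch.nn as nn

def assert_vars_change(model, train_step):
    """Run one step via the given callable and assert all parameters changed."""
    before = [p.detach().clone() for p in model.parameters()]
    train_step()
    for p0, p1 in zip(before, model.parameters()):
        assert not torch.equal(p0, p1), "a parameter did not change"

# hypothetical model/data to exercise the checker
model = nn.Linear(5, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
data, target = torch.randn(16, 5), torch.randn(16, 1)

def train_step():
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()
    optimizer.step()

assert_vars_change(model, train_step)
```

Because the optimizer and loss are captured by the closure, the checker stays agnostic to how the step is performed, much like passing an op to sess.run.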

I’m using gradcheck to unit test my models. It’s helping a lot…

from torch.autograd import gradcheck

    def test_sanity(self):
        # inputs must be double precision, otherwise gradcheck is unreliable
        inputs = (torch.randn(20, 20, dtype=torch.double, requires_grad=True),)
        model = nn.Linear(20, 1).double()
        test = gradcheck(model, inputs, eps=1e-6, atol=1e-4)
        self.assertTrue(test)

Tip: make sure you use .double(); gradcheck will fail with 32-bit floats.
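To see why: with float32 and eps=1e-6, the finite-difference perturbation sits near machine precision, so the numerical Jacobian is dominated by rounding noise. A quick sketch (using raise_exception=False so the float32 failure comes back as a boolean instead of an error; the layer sizes are arbitrary):

```python
import torch
import torch.nn as nn
from torch.autograd import gradcheck

# double-precision input: the finite-difference check is reliable
inp64 = (torch.randn(10, 10, dtype=torch.double, requires_grad=True),)
ok64 = gradcheck(nn.Linear(10, 1).double(), inp64, eps=1e-6, atol=1e-4)

# single-precision input: eps=1e-6 is near float32 machine epsilon,
# so the numerical gradient is too noisy and the check fails
inp32 = (torch.randn(10, 10, requires_grad=True),)
ok32 = gradcheck(nn.Linear(10, 1), inp32, eps=1e-6, atol=1e-4,
                 raise_exception=False)

print(ok64, ok32)
```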

Also… Interpreting gradcheck errors

3 Likes

I’ve ported the tests from mltest to PyTorch. It’s available at suriyadeepan/torchtest.

You can install it via pip

pip install torchtest
9 Likes