How to confirm that PyTorch code is working as intended / Unit testing PyTorch code?

The question is exactly as the title says. Suppose I have a piece of code and I want to make sure that all the backprop updates happen as I expect. I am looking for something along the lines of unit testing, or some other principled approach.


You could create unit tests by storing the values of your parameters and checking for updates after a training iteration. The concept is explained here. It’s for TensorFlow, but you can easily adapt it to PyTorch.
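A minimal sketch of that idea in PyTorch: snapshot every parameter, run one training iteration, and check that each parameter tensor moved. `params_changed_after_step` is a hypothetical helper name, and the toy model, random data, SGD optimizer and learning rate are arbitrary placeholders chosen only for illustration.

```python
import torch
import torch.nn as nn


def params_changed_after_step(model, loss_fn, optimizer, inputs, targets):
    """Hypothetical helper: return {param_name: changed?} after one iteration."""
    # Snapshot every parameter; .clone() copies the data so the optimizer's
    # in-place update cannot modify the snapshot.
    before = {name: p.detach().clone() for name, p in model.named_parameters()}

    # One ordinary training iteration.
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()

    # A parameter counts as "changed" if any element differs from its snapshot.
    return {
        name: not torch.equal(before[name], p.detach())
        for name, p in model.named_parameters()
    }


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs, targets = torch.randn(16, 4), torch.randint(0, 2, (16,))

changed = params_changed_after_step(model, nn.CrossEntropyLoss(), optimizer, inputs, targets)
for name, did_change in changed.items():
    assert did_change, f"{name} was not updated"
```

Comparing whole tensors keeps the check cheap; if you need to know that every individual weight moves, compare element-wise instead, though units that happen to be inactive for a batch can legitimately keep their values.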


That was an amazing post. Thank you for the answer. Time to apply it to my own code. Thanks a lot.

@ptrblck In the blog you mentioned, I found the code for testing in TensorFlow and am trying to convert it to PyTorch.

Is there any simple way to run a line of code defined elsewhere, the way it is done in TensorFlow? In particular, I want to replicate what is happening in L83, but I can’t think of an easy way to do that in PyTorch.

In PyTorch you don’t need to work with sessions like in TensorFlow. Just use .clone() to keep copies of your values.
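To see why .clone() matters, here is a small sketch with an arbitrary layer and random data. Optimizers such as torch.optim.SGD update parameters in place, so a plain reference to a parameter (or a .detach() without .clone(), which still shares storage) always reflects the latest values and compares equal to the parameter even after an update.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 1)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

alias = layer.weight                      # the same Parameter object, not a copy
snapshot = layer.weight.detach().clone()  # independent copy of the current values

loss = layer(torch.randn(5, 3)).pow(2).sum()
loss.backward()
optimizer.step()  # updates layer.weight in place

# The alias moved together with the parameter, so comparing against it
# would wrongly suggest that nothing changed.
print(torch.equal(alias, layer.weight))     # True
# The cloned snapshot still holds the old values, so the update is visible.
print(torch.equal(snapshot, layer.weight))  # False
```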

@ptrblck Thanks for the suggestion. Is there a similar trick that can be used for L84? I was planning on moving the exact operation into the function, but I am not sure if that is good practice.

It seems op is the “training operation”, e.g. op = tf.train.AdamOptimizer().minimize(var).
Using PyTorch you can just perform a training step:

output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
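The full step (zero_grad, forward, loss, backward, step) can also live in an ordinary function that you pass around and call inside a test, which roughly plays the role of `op` in the TensorFlow version. A sketch, assuming a hypothetical `train_step` helper and an arbitrary toy model, that checks frozen parameters stay identical while trainable ones move:

```python
import torch
import torch.nn as nn


def train_step(model, criterion, optimizer, data, target):
    """Hypothetical helper: one optimization step, roughly TensorFlow's `op`."""
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    return loss.item()


def snapshot(params):
    # Detached copies, so later in-place updates don't affect them.
    return [p.detach().clone() for p in params]


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# Freeze the first layer; only the last layer should be trained.
frozen = list(model[0].parameters())
trainable = list(model[2].parameters())
for p in frozen:
    p.requires_grad_(False)

optimizer = torch.optim.SGD(trainable, lr=0.1)
data, target = torch.randn(16, 4), torch.randn(16, 1)

frozen_before, trainable_before = snapshot(frozen), snapshot(trainable)
train_step(model, nn.MSELoss(), optimizer, data, target)

# Frozen parameters must be bit-for-bit identical; trainable ones must move.
assert all(torch.equal(b, p) for b, p in zip(frozen_before, frozen))
assert all(not torch.equal(b, p) for b, p in zip(trainable_before, trainable))
```

Passing the model, criterion and optimizer in as arguments keeps the step reusable, so the same function can drive the real training loop and the tests.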

Does this help?

I’m using gradcheck to unit test my models. It’s helping a lot.

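import unittest
import torch
import torch.nn as nn
from torch.autograd.gradcheck import gradcheck


class ModelTest(unittest.TestCase):
    def test_sanity(self):
        # gradcheck compares analytical and numerical gradients; use float64
        input = (torch.randn(20, 20, dtype=torch.double, requires_grad=True),)
        model = nn.Linear(20, 1).double()
        self.assertTrue(gradcheck(model, input, eps=1e-6, atol=1e-4))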

Tip: make sure you use .double(); gradcheck is likely to fail with only 32-bit floats.
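gradcheck is especially useful for custom autograd Functions with a hand-written backward. A sketch using a hypothetical `Cube` op (y = x**3) and a deliberately broken `BuggyCube` to show that a wrong derivative is caught; passing raise_exception=False makes gradcheck return False instead of raising an error.

```python
import torch
from torch.autograd import Function
from torch.autograd.gradcheck import gradcheck


class Cube(Function):
    """Hypothetical custom op y = x**3 with a correct hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d(x^3)/dx = 3x^2, chained with the incoming gradient
        return grad_output * 3 * x ** 2


class BuggyCube(Function):
    """Same forward, but the backward uses the wrong constant on purpose."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x ** 2  # bug: should be 3 * x ** 2


torch.manual_seed(0)
# float64 input, since gradcheck's finite differences need the precision
x = torch.randn(10, dtype=torch.double, requires_grad=True)

ok = gradcheck(Cube.apply, (x,), eps=1e-6, atol=1e-4)
bad = gradcheck(BuggyCube.apply, (x,), eps=1e-6, atol=1e-4, raise_exception=False)
print(ok, bad)  # True False
```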

See also: Interpreting gradcheck errors


I’ve ported the tests from mltest to PyTorch. The port is available at suriyadeepan/torchtest.

You can install it via pip:

pip install torchtest
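mltest-style suites typically also include sanity checks beyond "variables change", such as making sure nothing turns into NaN or Inf during training. If you would rather not add a dependency, a rough hand-rolled version of that idea might look like this; `assert_finite_training` is a hypothetical helper, not torchtest's API, and the model and data are placeholders.

```python
import torch
import torch.nn as nn


def assert_finite_training(model, criterion, optimizer, data, target, steps=5):
    """Hypothetical helper: fail if outputs, loss or params become NaN/Inf."""
    for step in range(steps):
        optimizer.zero_grad()
        output = model(data)
        assert torch.isfinite(output).all(), f"non-finite output at step {step}"
        loss = criterion(output, target)
        assert torch.isfinite(loss), f"non-finite loss at step {step}"
        loss.backward()
        optimizer.step()
        # Check parameters after each update, since a bad step shows up here first.
        for name, p in model.named_parameters():
            assert torch.isfinite(p).all(), f"non-finite {name} at step {step}"


torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
assert_finite_training(model, nn.MSELoss(), optimizer, torch.randn(8, 4), torch.randn(8, 1))
```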