Training with gradient checkpoints (torch.utils.checkpoint) appears to reduce model performance

I have a snippet of code that uses gradient checkpointing from torch.utils.checkpoint to reduce GPU memory usage:

    if use_checkpointing:
        res2, res3, res4, res5 = checkpoint.checkpoint(self.resnet_backbone, data['data'])
        fpn_p2, fpn_p3, fpn_p4, fpn_p5, fpn_p6 = checkpoint.checkpoint(self.fpn, res2, res3, res4, res5)
    else:
        res2, res3, res4, res5 = self.resnet_backbone(data['data'])
        fpn_p2, fpn_p3, fpn_p4, fpn_p5, fpn_p6 = self.fpn(res2, res3, res4, res5)

When I train with gradient checkpointing enabled, my model's final performance is worse. Does anyone know why this might be?


Hi,

This most likely happens because the first part of your model doesn't get gradients, due to some quirks of how checkpointing works.
Can you try making data['data'] require gradients before giving it to the checkpoint? (You can ignore the computed gradient; just add a data['data'].requires_grad_() if it is a Tensor.)
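
In code, that suggestion looks roughly like this (a sketch applied to the snippet above; it assumes data['data'] is a floating-point tensor, since requires_grad_() only works on those):

    inp = data['data']
    if torch.is_tensor(inp) and inp.is_floating_point():
        inp.requires_grad_()  # gradient w.r.t. the input is computed but never used
    res2, res3, res4, res5 = checkpoint.checkpoint(self.resnet_backbone, inp)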


When I do that, I appear to get a new error:

        loss.backward()
      File "/pythonhome_pypi/lib/python3.6/site-packages/torch/tensor.py", line 195, in backward
        torch.autograd.backward(self, gradient, retain_graph, create_graph)
      File "/pythonhome_pypi/lib/python3.6/site-packages/torch/autograd/__init__.py", line 99, in backward
        allow_unreachable=True)  # allow_unreachable flag
      File "/pythonhome_pypi/lib/python3.6/site-packages/torch/autograd/function.py", line 77, in apply
        return self._forward_cls.backward(self, *args)
      File "/pythonhome_pypi/lib/python3.6/site-packages/torch/utils/checkpoint.py", line 99, in backward
        torch.autograd.backward(outputs, args)
      File "/pythonhome_pypi/lib/python3.6/site-packages/torch/autograd/__init__.py", line 99, in backward
        allow_unreachable=True)  # allow_unreachable flag
    RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

To be specific, I added .requires_grad_():

    res2, res3, res4, res5 = checkpoint.checkpoint(self.resnet_backbone, data['data'].requires_grad_())

Your loss does not require gradients anymore because you added an extra requires_grad_()? That looks weird. Did you change anything else, by any chance?

Yeah, that is all I added; I just double-checked and this is the case D:

Can you double-check whether the outputs of the checkpoint require gradients or not, in both cases?

I tested all cases:

    res2, res3, res4, res5 = checkpoint.checkpoint(self.resnet_backbone, data['data'])
    print(res2.requires_grad, res3.requires_grad, res4.requires_grad, res5.requires_grad)
    res2, res3, res4, res5 = checkpoint.checkpoint(self.resnet_backbone, data['data'].requires_grad_())
    print(res2.requires_grad, res3.requires_grad, res4.requires_grad, res5.requires_grad)
    res2, res3, res4, res5 = self.resnet_backbone(data['data'])
    print(res2.requires_grad, res3.requires_grad, res4.requires_grad, res5.requires_grad)

This outputs:

    False False False False
    True True True True
    False True True True

In the original resnet backbone architecture of my code I have frozen the res2 layer; could this be the issue?

Edit: it appears this is the culprit, as I can train when I unfreeze the res2 layer. I suppose to fix this I can just .detach() it, i.e.:

    res2, res3, res4, res5 = checkpoint.checkpoint(self.resnet_backbone, data['data'].requires_grad_())
    res2 = res2.detach()

Ok,

So adding requires_grad on the data does fix the issue of the backbone not learning.

For the res2 freezing, I'm not sure what you mean here.
The checkpoint has the behavior that it makes all outputs require gradients, because it does not yet know which elements will actually require them.
Note that in the final computation during the backward, that gradient should be discarded and not used, so the frozen part should remain frozen, even though you don't see it in the forward pass.

Adding the detach() after the checkpoint will work as well :slight_smile:

Yeah, so in my model I have code that detaches the gradients for res2, i.e.:

        with torch.no_grad():
            out = self.mod1(img)
            out = self.mod2(self.pool2(out))
            out = self.mod3(self.pool3(out)).detach()
        res2 = out

And since checkpoint has the behaviour that it makes all outputs require a gradient, I think it gets conflicted when it tries to do that here: in the model definition no gradient is required, but the requires_grad flag is true. Hence it outputs the error RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn. Or so I think is the case…
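
A small standalone sketch of that flag behaviour (the layers here are made up for illustration; use_reentrant=True selects the reentrant implementation, which is the only one in the PyTorch version shown in the traceback above):

    import torch
    from torch.utils import checkpoint

    frozen = torch.nn.Linear(4, 4)     # stands in for the frozen res2 branch
    trainable = torch.nn.Linear(4, 4)  # stands in for the trainable layers

    def backbone(x):
        with torch.no_grad():          # frozen branch, like the res2 code above
            a = frozen(x)
        b = trainable(x)
        return a, b

    x = torch.randn(2, 4, requires_grad=True)
    a, b = backbone(x)
    print(a.requires_grad, b.requires_grad)  # False True when called directly

    a, b = checkpoint.checkpoint(backbone, x, use_reentrant=True)
    print(a.requires_grad, b.requires_grad)  # True True: every checkpoint output gets the flag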

Oh, good catch!
Another quirk of the checkpoint module…
So here, indeed, it might be simpler to remove this special code and detach() outside of the checkpoint.

Note that in this code, the detach() is not needed, as the no_grad() block already prevents any graph from being created.
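
As a quick check of that last point (a standalone snippet, not the code above):

    import torch

    lin = torch.nn.Linear(4, 4)
    x = torch.randn(2, 4, requires_grad=True)
    with torch.no_grad():
        y = lin(x)
    # the output is already cut from the graph, so an extra .detach() changes nothing
    print(y.requires_grad, y.grad_fn)  # False None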

Hey guys :slight_smile:

It seems that I have a similar problem. My model is made up of a backbone CNN (namely ResNet50) and some individual layers at the end. I added checkpoint.checkpoint_sequential() for the resnet and checkpoint.checkpoint() for the individual layers at the end. The performance of my model drops heavily.
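
Roughly, the setup is something like this (a sketch with placeholder layers and segment count, not my exact code):

    import torch
    from torch.utils import checkpoint
    from torchvision.models import resnet50

    # ResNet50 conv stages wrapped in nn.Sequential so checkpoint_sequential can split them
    backbone = torch.nn.Sequential(*list(resnet50().children())[:-2])
    # placeholder for the individual layers at the end
    head = torch.nn.Sequential(torch.nn.AdaptiveAvgPool2d(1),
                               torch.nn.Flatten(),
                               torch.nn.Linear(2048, 10))

    x = torch.randn(2, 3, 224, 224, requires_grad=True)      # input made to require grad
    feats = checkpoint.checkpoint_sequential(backbone, 4, x)  # backbone in 4 checkpointed segments
    out = checkpoint.checkpoint(head, feats)                  # checkpoint() for the final layers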

So I checked requires_grad on the outputs of the checkpoints as proposed in this post. In the first epoch everything was fine, but in the second epoch it turned to False. I tried to add requires_grad_(), which resolved the problem for checkpoint.checkpoint().

However, for checkpoint.checkpoint_sequential() it still gives False in the second epoch. The weird thing is that during evaluation, i.e. with torch.no_grad(), requires_grad on the output of the checkpoints resolved to True :face_with_raised_eyebrow:

If anyone has any idea, I would be pretty thankful :slight_smile:

Edit: I changed the checkpointing inside the resnet to checkpoint.checkpoint() too, and now requires_grad is set to False for both after the checkpointing, after the first epoch…