Checkpoint for a whole subnet

TL;DR: Is it possible to checkpoint a whole sub-network and still learn its weights, with a nice API?

I’ve learned that I can replace

result = self.net(x, y)[0]

with

result = torch.utils.checkpoint.checkpoint(self.net, x, y)[0]

where in my case self.net is a sub-model with some logic and several layers, and x and y are inputs from some earlier layer. As far as I understand, x and y will be stored during the forward pass, but no other value inside self.net. On the backward pass, the gradient coming from result hits the checkpoint, which reruns the forward computation of self.net to get the gradients for x and y. This trades computation for memory.
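A minimal sketch of the recomputation described above, checking that the checkpointed call yields the same input gradient as the plain call. `TwoInputNet` is a hypothetical stand-in for self.net, and the explicit `use_reentrant=True` argument assumes a PyTorch version that accepts that keyword.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class TwoInputNet(nn.Module):
    """Hypothetical stand-in for self.net: takes x and y, returns a tuple."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 16)
        self.fc2 = nn.Linear(16, 4)

    def forward(self, x, y):
        hidden = torch.tanh(self.fc1(x + y))
        return (self.fc2(hidden),)


torch.manual_seed(0)
net = TwoInputNet()
x = torch.randn(5, 8, requires_grad=True)
y = torch.randn(5, 8, requires_grad=True)

# Plain call: intermediate activations are kept alive for the backward pass.
net(x, y)[0].sum().backward()
grad_x_plain = x.grad.clone()

# Reset gradients before the second run.
x.grad = None
y.grad = None
net.zero_grad()

# Checkpointed call: only the inputs are saved; the forward pass of net
# is run again during backward to rebuild the intermediate activations.
checkpoint(net, x, y, use_reentrant=True)[0].sum().backward()
grad_x_checkpointed = x.grad.clone()
```

Both runs should produce the same gradient for x; the difference is only in what is held in memory between forward and backward.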

Previously I didn’t need the gradients in self.net (it was pre-trained), but now I do. So I would like to have something like

result = torch.utils.checkpoint.checkpoint(self.net, self.net.parameters(), x, y)[0]

As far as I understand, technically this is very similar to the first case, but the API doesn’t allow it because self.net.parameters() is an iterator.

I’m not very experienced with Python and PyTorch (I only started last summer), but I managed to do the following (and it seems to work):

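class _NetCheckpointWrapper:
    def __init__(self, net, x, y):
        self.net = net
        self.x = x
        self.y = y

    def __call__(self, *args, **kwargs):
        # The arguments (the parameters) are ignored here; they are passed
        # only so that checkpoint sees inputs that require gradients.
        return self.net(self.x, self.y)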

# forward:
wrapper = _NetCheckpointWrapper(self.net, x, y)
net_params = tuple(self.net.parameters())
result = torch.utils.checkpoint.checkpoint(wrapper, *net_params)[0]

This doesn’t look very good to me as a programmer, but maybe it is idiomatic in Python/PyTorch? I’m not sure about the usage of net_params: coming from C++, I’m afraid net_params might refer to different objects than self.net.parameters(), or that it might change in the future. It is quite verbose as well.

Is there a better way to achieve that?


The short answer is that you don’t need to worry. If all the parameters of self.net require gradients, you can simply do result = torch.utils.checkpoint.checkpoint(self.net, x, y)[0] and call loss.backward(), and all the parameters inside will have their .grad field populated.

The longer answer is that the restriction on using autograd.grad with checkpoint exists precisely to make this simple construct possible. You don’t need to pass everything that requires gradients as an input to the checkpoint, only the inputs to your network.
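As a rough illustration of the point above, the sketch below checkpoints a small sub-network without passing its parameters and checks that they still receive gradients. The names `sub_net` and `earlier_layer` are hypothetical, a single input is used instead of x and y for brevity, and `use_reentrant=True` assumes a PyTorch version that accepts that keyword.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
earlier_layer = nn.Linear(8, 8)  # hypothetical layer feeding the sub-network
sub_net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))

# The activations require gradients because earlier_layer has trainable weights.
features = earlier_layer(torch.randn(4, 8))

# Only the network inputs are passed; the parameters of sub_net are not.
result = checkpoint(sub_net, features, use_reentrant=True)
loss = result.pow(2).mean()
loss.backward()

# True if every parameter inside the checkpointed sub-network got a gradient.
param_grads_present = all(p.grad is not None for p in sub_net.parameters())
```

Here `param_grads_present` should come out True, and the weights of `earlier_layer` receive gradients as well.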

Thanks for the answer!

I tried that first, and I get:

UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
Traceback (most recent call last):
  File "..", line 249, in train_on
  File "/home/madam/bin/anaconda3/lib/python3.7/site-packages/torch/", line 166, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/madam/bin/anaconda3/lib/python3.7/site-packages/torch/autograd/", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

My thinking was that checkpoint removes all requires_grad by design and restores it only for the tensors passed in as arguments, which is why it’s not working.

I’m sorry, I don’t understand this. autograd.grad does forward differentiation, doesn’t it? My understanding is that I’m doing backward differentiation only (but maybe it internally inverts the computation tree?).

I’m sorry, I don’t understand this. autograd.grad does forward differentiation, doesn’t it?

No, it does backward-mode differentiation as well. The only difference is that instead of populating the .grad field of every leaf Tensor, it computes and returns the gradients for the list of Tensors you give it.
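A small sketch of that difference on a scalar example (the tensors `w` and `x` are made up for illustration; no checkpoint is involved here):

```python
import torch

w = torch.tensor(3.0, requires_grad=True)  # leaf tensor
x = torch.tensor(2.0)                      # plain data, no gradient needed

loss = (w * x) ** 2

# autograd.grad runs backward-mode differentiation and returns the gradient;
# it does not touch w.grad. retain_graph=True keeps the graph for a second pass.
(grad_w,) = torch.autograd.grad(loss, [w], retain_graph=True)
grad_field_after_grad_call = w.grad  # still None at this point

# backward() runs the same backward pass but stores the result in w.grad.
loss.backward()
```

Both calls compute d(loss)/dw = 2 * w * x**2 = 24; only where the value ends up differs.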

My thinking was that checkpoint removes all requires_grad by design

No, the output of the checkpoint module should require gradients. If you put the whole network into it, you might need to set requires_grad=True on the input, though. Does that help?

Thanks. Yes, it does help. Although I’m not totally happy with the result either (setting requires_grad=True on a tensor that doesn’t need it):

if not y.requires_grad:
    # Require a gradient for y even when only learning net;
    # otherwise checkpoint doesn't work.
    y.requires_grad_(True)
result = torch.utils.checkpoint.checkpoint(self.net, x, y)

Unfortunately, this is a limitation of the checkpoint module. At least one input to the module needs to require gradients.

Since the first module of your net usually has weights that require gradients, setting the flag on the input should not change the graph that is created.
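The limitation and the workaround can be reproduced in a few lines. This is a sketch under assumptions: `net` is a hypothetical one-layer stand-in for self.net, and the `use_reentrant` keyword requires a PyTorch version that supports it. The last part uses `use_reentrant=False`, a different checkpoint implementation that, according to the PyTorch documentation, does not need an input with requires_grad=True; it is shown only as a possible alternative, not as what this thread used.

```python
import warnings

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
net = nn.Linear(8, 1)   # hypothetical stand-in for self.net
y = torch.randn(4, 8)   # plain data: requires_grad is False

# 1) No input requires gradients: the output is cut off from the graph.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the "None of the inputs..." warning
    out_detached = checkpoint(net, y, use_reentrant=True)
detached_requires_grad = out_detached.requires_grad

# 2) Workaround from this thread: flag the input as requiring gradients.
y.requires_grad_(True)
out_fixed = checkpoint(net, y, use_reentrant=True)
out_fixed.sum().backward()
weight_grad_reentrant = net.weight.grad.clone()

# 3) Alternative (assumption: non-reentrant mode is available): the input
#    can stay a plain tensor and the parameters still get gradients.
net.zero_grad()
y_plain = y.detach()
out_nonreentrant = checkpoint(net, y_plain, use_reentrant=False)
out_nonreentrant.sum().backward()
weight_grad_nonreentrant = net.weight.grad.clone()
```

In step 1, calling backward on `out_detached` would raise the same "does not require grad" RuntimeError shown in the traceback earlier in the thread.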
