Checkpoint with no grad-requiring inputs

The first thing that happens in my model's forward method is calling checkpoint a few times, using several feature extractors.
However, I get the following warning:

 UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")

This issue seems to be described here:

Is there any recommended workaround for this right now?

At first I thought about doing something like:

input_data.requires_grad_(True)

and then do:
y = checkpoint(foo, input_data)

But I think that would also treat the input as a parameter (as is done in sensitivity/saliency-map creation), which isn’t something I need or want to add to the computation.

Should I add/use some sort of dummy function before calling checkpoint?
Any suggestions?

2 Likes

I don’t think you need checkpoints if no gradients are calculated.
Checkpointing trades compute for memory by re-calculating, during the backward pass, the intermediate activations needed for the weight updates. Since apparently no gradients are needed in your computation, you won’t get any benefit from checkpointing.
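
For reference, a minimal sketch of that trade-off in the case where gradients are needed (the layer sizes are arbitrary placeholders, and the input is given requires_grad=True here so the no-grad-input warning does not fire):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
x = torch.randn(64, 512, requires_grad=True)  # input requires grad, so the warning is avoided

# Without checkpointing: all intermediate activations inside `block` are stored for backward.
y_plain = block(x)

# With checkpointing: activations inside `block` are discarded during forward
# and recomputed during backward, trading extra compute for lower memory usage.
y_ckpt = checkpoint(block, x)
y_ckpt.sum().backward()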

Thanks for the reply @ptrblck : )

I do need gradients calculated, and the reason I’m using checkpointing is exactly to trade speed for memory.
I probably did not explain myself well, sorry.

See, for example, the minimal reproducing code that was mentioned in the GitHub issue I linked above:

import torch
from torch.utils.checkpoint import checkpoint

m = torch.nn.Linear(4,3)
x = torch.randn(10,4)

z1 = checkpoint(lambda inp: m(inp), x)  # checkpointed call; x does not require grad
print(z1.requires_grad)

z2 = m(x)  # regular call
print(z2.requires_grad)

If you run this code it will output:

.../miniconda3/envs/py36torch/lib/python3.6/site-packages/torch/utils/checkpoint.py:20: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
False
True

As you can see, no gradients will be computed in the checkpoint example (for any of the parameters inside torch.nn.Linear(…)).
On the other hand, gradients WILL be computed in the example that doesn’t use checkpoint, even though all of its inputs have requires_grad=False!

This is most problematic when you want to use checkpoint as the first operation(s) in your network.

For example, to explain my motivation: I want to use it to calculate features for several images and then combine those features (with everything in one computation graph).
Each image’s feature extraction is checkpointed, so hopefully this will keep me from running out of GPU memory.

Of course, maybe I’m missing something here, so please enlighten me if I am : )

1 Like

Thanks for the explanation. I misunderstood the issue.
I played around a bit and also noticed that the problem occurs when you checkpoint the first operation.

However, could you explain your use case a bit more?
You have a model which calculates features for several images.
I assume each image has its own “computation path”. Is this right?
Each of these submodules is checkpointed to save some memory, and somehow you get the gradient warning from your first post.

Did you freeze the model and just use it for feature extraction, or do you train it?
Could you post a small dummy model showing your use case?

2 Likes

@ptrblck

Text description first: (code follows)

I train the model from scratch. There is only one sub-graph that does the feature extraction, but I use it multiple times, once per image, in a single forward pass, so the feature-extractor parameters are shared. You can think of it as a Siamese network gone wild :slight_smile:
Multiple images are passed into the forward function, and each image goes through:

  1. An InceptionResnetV2 feature extractor, using only the first few layers and no fully connected layers.
  2. A global max pooling layer, so we now have a 1-d vector of features per image (for example, 1088 features).

After doing this for, say, 10 images, we have 10 individual 1-d vectors, each of size 1088.
We stack them into a single matrix of shape (10, 1088)
and, for example, perform another global max pooling on them, reaching a 1-d feature vector of size 1088 that represents the max features across all images.
From that point, a standard fully connected layer and an activation suitable for negative log-likelihood are used.
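
For illustration, a minimal sketch of the stack-and-max step described above (the feature vectors are random placeholders standing in for the 10 per-image features):

import torch

feats = [torch.randn(1088) for _ in range(10)]  # one 1-d feature vector per image
stacked = torch.stack(feats)                    # shape (10, 1088)
combined, _ = stacked.max(dim=0)                # element-wise max across images -> shape (1088,)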

Code:
I’ll try to demonstrate it here in code. The code is not self-contained, but hopefully it demonstrates the scenario well, at least for static viewing.
Please tell me if it doesn’t and I’ll improve it.

class ReproduceIssue(nn.Module):
    def __init__(self):
        super().__init__()
       
        #this is an InceptionResnetV2 model, but only the first K layers (and no fully connected layers)
        self.feature_extractor = inceptionresnetv2.inceptionresnetv2(
            pretrained=False,
            num_classes=0,
            logical_units_num=14, #keep only the first 14 parts (it's up to mixed_6a)
            input_channels_num=1,
            final_global_pooling='max' #end it with global max pooling to get a 1d vector per sample
        )
        self.fc1 = MyDense(1088, 2, activation=None, batch_norm=False, dropout_rate=None)
        self._init_vars() #...

    def forward(self, *args):
        extracted_features = []
        for x in args:
            #curr_feat = self.feature_extractor(x)  # original call, before checkpointing
            curr_feat = checkpoint(self.feature_extractor, x)
            extracted_features.append(curr_feat)

        #now that we have all features, stack and perform an additional global max pooling
        stacked_slices_features = torch.stack(extracted_features)
        stacked_slices_features = stacked_slices_features[:, 0, ...]
        permuted_stacked = stacked_slices_features.permute(2, 1, 0, 3)
        vol_features = F.max_pool2d(permuted_stacked, kernel_size=permuted_stacked.shape[2:])
        
        #now, take our 1d vector and continue into a fully connected layer, and final activation
        logits = self.fc1(vol_features[:, :, 0, 0])
        preds = nn.LogSoftmax(dim=1)(logits) 
        return preds

Note: For the sake of simplicity, assume that a minibatch is always of size 1.
(My code does support minibatches of arbitrary sizes, but that overcomplicates the demonstration code, so I did not include it.)

Thanks for the clarification and the nicely simplified code!
Maybe we could add a dummy layer between self.feature_extractor and the append.
I created a small example with a resnet18, which at least doesn’t throw any errors.
Could you try that?

import torch
from torch import nn
from torchvision import models
from torch.utils.checkpoint import checkpoint

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.features = nn.Sequential(*list(models.resnet18(pretrained=False).children())[:5])
        self.fc1 = nn.Linear(200704, 2)
        
    def forward(self, x):
        x = self.features(x)
        x = checkpoint(lambda x: x, x)
        print(x.shape)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

model = MyModel()
x = torch.randn(1, 3, 224, 224)
output = model(x)
output.mean().backward()
print(model.features[0].weight.grad)

Thanks @ptrblck !

I believe that in the example you created it doesn’t throw any errors because you are not checkpointing the first layer: you first run self.features on the input and only then checkpoint an identity-mapping lambda,
which kind of defeats the purpose.

Perhaps your intention was to create a dummy layer before the feature extraction? This is what I’m trying to do right now.

This is my attempt at a workaround, now in standalone code thanks to your example : )
If you have an idea for a dummy layer that is more lightweight, and/or a codebase solution (one that makes the checkpoint function ignore the fact that the input has no gradients), I’d appreciate it a lot :slight_smile:

import torch
from torch import nn
from torchvision import models
from torch.utils.checkpoint import checkpoint

class DummyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.dummy = nn.Parameter(torch.ones(1, dtype=torch.float32))
    def forward(self,x):
        return x + self.dummy - self.dummy #(also tried x+self.dummy)

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.features = nn.Sequential(*list(models.resnet18(pretrained=False).children())[:5])
        self.fc1 = nn.Linear(200704, 2)

        self.dummy_layer = DummyLayer()

    def forward(self, x):
        x = self.dummy_layer(x)
        x = checkpoint(self.features, x)
        print(x.shape)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

model = MyModel().cuda()
x = torch.randn(1, 3, 224, 224).cuda()
output = model(x)
output.mean().backward()
print(model.features[0].weight.grad)

While it works in this synthetic example (and with these “small” images), in my actual scenario I’m at the very edge of GPU memory even for a single slice (and I need all of it for proper feature extraction),
so even this dummy layer makes me hit OOM (even if I don’t use checkpointing).

I think I had misunderstood the checkpoint util, which is why my example seems to do nothing.
I thought you had to checkpoint one layer somewhere in your model and all layers before it would be checkpointed.
As it turns out, you have to checkpoint all the layers together, which makes more sense.
I’ll try to have a look at your code later on, since today is quite busy.

1 Like

A small update. I’m reading a tutorial on checkpointing models from the original author @Priya_Goyal and this part seems to be interesting for your use case:

NOTE: In case of checkpointing, if all the inputs don’t require grad but the outputs do, then if the inputs are passed as is, the output of Checkpoint will be variable which don’t require grad and autograd tape will break there. To get around, you can pass a dummy input which requires grad but isn’t necessarily used in computation.

6 Likes

Very cool!

Until now I tried calling requires_grad_() on the actual input (which affects the computation and creates overhead) - but the idea of an additional dummy input which isn’t used in the computation sounds great!!
Trying it now and will update.

1 Like

Very cool - this workaround seems to do the trick perfectly!!

Thanks a lot :smiley:

Here’s a standalone implementation of it for reference:

import torch
from torch import nn
from torchvision import models
from torch.utils.checkpoint import checkpoint

class ModuleWrapperIgnores2ndArg(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x, dummy_arg=None):
        assert dummy_arg is not None
        x = self.module(x)
        return x

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(*list(models.resnet18(pretrained=False).children())[:5])
        self.fc1 = nn.Linear(200704, 2)
        self.dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True)  # dummy input that requires grad; not actually used in the computation
        self.module_wrapper = ModuleWrapperIgnores2ndArg(self.features)

    def forward(self, x):
        #x = checkpoint(self.features, x)
        x = checkpoint(self.module_wrapper, x, self.dummy_tensor)
        print(x.shape)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

model = MyModel().cuda()
x = torch.randn(1, 3, 224, 224).cuda()
output = model(x)
output.mean().backward()
print(model.features[0].weight.grad)

9 Likes

Nice work! I think we both learned quite a bit in this thread :wink:

3 Likes

hehe - indeed !
Thanks again :slight_smile:

Hey, I am facing a similar error to the one you mention in the question. However, setting input.requires_grad = True fixes the error. Could you elaborate on why you are using a dummy variable and not the input itself? What exactly is the overhead you are talking about?

Sure, I’ll explain.
When you use requires_grad=True on the input data, you are telling PyTorch to calculate a gradient for it.
This is actually a practical feature in a few cool use cases, but it does create overhead when it’s not what you really want, since you usually have no intention of changing, for example, the pixel values of the input image.

The overhead is pretty straightforward: a gradient tensor is calculated for the input, which adds both to the calculation time and to the GPU memory used to store that gradient.

In some cases, like small images, I guess it won’t matter that much. However, in my use case, which involves very big 3-D volumes of data as input, it does matter a lot: even if we ignore the calculation-time overhead, it would add over 1 GB of GPU memory usage.
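
As a back-of-the-envelope check of that figure (the volume shape below is only an illustrative assumption, not my actual data):

import torch

# A hypothetical single-channel 3-D volume in float32, shape (1, 1, 512, 768, 768).
volume = torch.zeros(1, 1, 512, 768, 768)
grad_bytes = volume.numel() * volume.element_size()  # the gradient tensor has the same size as the input
print(f"{grad_bytes / 1024**3:.2f} GiB")  # ~1.12 GiB just for the input's gradient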

4 Likes

May I ask if there is any convenient solution for torch.utils.checkpoint.checkpoint_sequential? Thanks!
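
For reference, one possible sketch of an approach, assuming a recent PyTorch where checkpoint_sequential accepts use_reentrant (otherwise, applying the dummy-input trick above to the input of the first segment is the fallback):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

layers = nn.Sequential(*[nn.Linear(128, 128) for _ in range(6)])
x = torch.randn(2, 128)  # input does not require grad

# Split the Sequential into 3 segments; with the non-reentrant implementation,
# gradients still flow to the layer weights even though the input has requires_grad=False.
out = checkpoint_sequential(layers, 3, x, use_reentrant=False)
out.mean().backward()
print(layers[0].weight.grad is not None)  # expected: True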

Great job! It works for my case.

I believe passing the use_reentrant=False argument to torch.utils.checkpoint.checkpoint() will solve this problem.
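
For example, revisiting the minimal repro from earlier in the thread (assuming a PyTorch version that supports the use_reentrant argument):

import torch
from torch.utils.checkpoint import checkpoint

m = torch.nn.Linear(4, 3)
x = torch.randn(10, 4)  # does not require grad

z = checkpoint(m, x, use_reentrant=False)
print(z.requires_grad)            # True: the non-reentrant version keeps the graph to m's parameters
z.sum().backward()
print(m.weight.grad is not None)  # True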

3 Likes

use_reentrant=False solves a similar problem (with another model)

1 Like

See gradient checkpointing disables requires_grad when freezing part of models (fix with use_reentrant=False) · Issue #21381 · huggingface/transformers · GitHub

1 Like