About the implementation of activation checkpointing

I read the source code and found that when using activation checkpointing, the block's forward runs under torch.no_grad, so the outputs are supposed to have requires_grad set to False. However, after I called checkpoint, the output tensor still required gradients, and I couldn't figure out why this happens. I hope someone can help me with this.

I’m not sure what your exact use case is, but the docs for checkpoint_sequential mention:

All segments except the last will run in torch.no_grad() manner, i.e., not storing the intermediate activations. The inputs of each checkpointed segment will be saved for re-running the segment in the backward pass.

Could this explain the behavior you are seeing?
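
For context, a minimal sketch of how checkpoint_sequential is typically used (the layer stack and segment count below are made up for illustration):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

# A purely illustrative stack of layers.
model = nn.Sequential(
    nn.Linear(100, 100),
    nn.ReLU(),
    nn.Linear(100, 100),
    nn.ReLU(),
)

x = torch.randn(2, 100, requires_grad=True)

# Split the Sequential into 2 segments: all segments except the last run
# under no_grad and only their inputs are stashed for the backward re-run.
# (Newer releases may also ask for an explicit use_reentrant argument.)
out = checkpoint_sequential(model, 2, x)
out.sum().backward()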

hi @ptrblck I’m not facing any problem with its usage, I’m just curious about what is actually going on. Here is a small demo to reproduce the results:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(100, 100)
        self.blocks = nn.Sequential(*[
            nn.Linear(100, 100),
            nn.Linear(100, 100)
        ])

    def forward(self, x):
        x = self.linear(x)
        x1 = x2 = x

        # Plain no_grad forward: the outputs do not require gradients.
        with torch.no_grad():
            for block in self.blocks:
                x1 = block(x1)
                print(f'no_grad requires_grad: {x1.requires_grad}')

        # The same blocks run through checkpoint: the outputs report requires_grad=True.
        for block in self.blocks:
            x2 = checkpoint(block, x2)
            print(f'checkpoint requires_grad: {x2.requires_grad}')

        return x2

model = Model()
x = torch.randn((2, 100))
outs = model(x)

The outputs are as follows:

no_grad requires_grad: False
no_grad requires_grad: False
checkpoint requires_grad: True
checkpoint requires_grad: True

For the checkpoint outputs, requires_grad is set to True, which differs from the description of activation checkpointing.

Thanks for the code snippet, as now I see where the unexpected behavior is coming from.
If you add a debug print statement into the forward of CheckpointFunction here, you’ll see that output indeed does not require gradients anymore. However, the returned tensor reports True in its requires_grad attribute again, as this seems to be a property of an autograd.Function:

import torch

class MyCheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        print(x.requires_grad)
        
        with torch.no_grad():
            y = x + 1
        print(y.requires_grad)
        return y
    
f = MyCheckpointFunction.apply
x = torch.randn(1, requires_grad=True)
out = f(x)
print(out.requires_grad)

Output:

True
False
True
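
The True on the output is the default behavior of a custom autograd.Function when any of its inputs require gradients. You can see this default by explicitly opting out of it via ctx.mark_non_differentiable (a minimal sketch only; CheckpointFunction does not do this, since it needs its backward to run):

import torch

class MyNonDiffFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        with torch.no_grad():
            y = x + 1
        # Explicitly tell autograd this output is not differentiable.
        ctx.mark_non_differentiable(y)
        return y

x = torch.randn(1, requires_grad=True)
out = MyNonDiffFunction.apply(x)
print(out.requires_grad)  # False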

Yes! It is exactly the same misbehaviour I noticed in previous experiments, but I couldn’t find where the attribute is modified. Do you have any idea about this?

By the way, I found that the requires_grad attribute of the output tensor also depends on the inputs. Once I set requires_grad=False for the inputs, the output tensors also have requires_grad=False, as in the sketch below.
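
A minimal sketch of what I mean, assuming the reentrant checkpoint path discussed above (recent PyTorch versions ask for an explicit use_reentrant argument):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

block = nn.Linear(100, 100)

# Input that requires gradients -> the checkpointed output requires gradients.
x = torch.randn(2, 100, requires_grad=True)
print(checkpoint(block, x, use_reentrant=True).requires_grad)  # True

# Input that does not require gradients -> the output does not either
# (this also triggers a warning that none of the inputs require grad).
y = torch.randn(2, 100)
print(checkpoint(block, y, use_reentrant=True).requires_grad)  # False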

I don’t think this is a misbehavior; it is exactly how custom autograd.Functions are designed and how they have to act on input tensors to avoid breaking the computation graph.
E.g. a custom autograd.Function can be used to implement any operation from a 3rd party lib (such as numpy) which Autograd cannot track. If the output tensor did not require gradients anymore, your layer would never work inside a proper computation graph.
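
As a small sketch of that use case (the op and names below are made up for illustration), wrapping a numpy computation in a custom autograd.Function with a hand-written backward keeps the output inside the graph, even though Autograd cannot see the numpy call itself:

import numpy as np
import torch

class NumpySin(torch.autograd.Function):
    # The forward uses numpy, which autograd cannot trace,
    # so the backward is provided manually.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.from_numpy(np.sin(x.detach().cpu().numpy()))

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * torch.cos(x)

x = torch.randn(3, requires_grad=True)
out = NumpySin.apply(x)
print(out.requires_grad)  # True, so later layers can still backprop through it
out.sum().backward()
print(x.grad)             # matches cos(x)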

Thanks for your explanation. This makes things much clearer for me.