I read the source code and found that when using activation checkpointing, the block's forward runs under torch.no_grad, so the outputs should have requires_grad set to False. However, after I called checkpoint, the output tensor did require gradients, and I couldn't figure out why this happens. I hope someone can help me with this.
I'm not sure what your exact use case is, but the docs for checkpoint_sequential mention:

> All segments except the last will run in torch.no_grad() manner, i.e., not storing the intermediate activations. The inputs of each checkpointed segment will be saved for re-running the segment in the backward pass.

Could this explain the behavior you are seeing?
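To illustrate that passage (a hypothetical toy model, not taken from this thread): checkpoint_sequential splits a Sequential into segments, runs all but the last under no_grad, and re-runs them during backward, yet parameter gradients still come out as usual:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

# Hypothetical toy model: four layers, checkpointed as two segments.
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10),
    nn.ReLU(),
)
x = torch.randn(2, 10, requires_grad=True)

# All segments except the last run in no_grad mode; their activations
# are recomputed from the saved segment inputs during backward.
out = checkpoint_sequential(model, 2, x)
out.sum().backward()
print(model[0].weight.grad is not None)  # True: gradients are still computed
```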
Hi @ptrblck, I'm not facing any problem with its usage, I'm just curious about what is actually going on. Here is a small demo to reproduce the results:
```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(100, 100)
        self.blocks = nn.Sequential(*[
            nn.Linear(100, 100),
            nn.Linear(100, 100)
        ])

    def forward(self, x):
        x = self.linear(x)
        x1 = x2 = x
        with torch.no_grad():
            for block in self.blocks:
                x1 = block(x1)
                print(f'no_grad requires_grad: {x1.requires_grad}')
        for block in self.blocks:
            x2 = checkpoint(block, x2)
            print(f'checkpoint requires_grad: {x2.requires_grad}')
        return x2

model = Model()
x = torch.randn((2, 100))
outs = model(x)
```
The outputs are as follows:

```
no_grad requires_grad: False
no_grad requires_grad: False
checkpoint requires_grad: True
checkpoint requires_grad: True
```
For the checkpoint outputs, requires_grad is set to True, which differs from the description of activation checkpointing.
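For reference, the requires_grad=True output is not just cosmetic: when the input requires grad, backward through checkpoint recomputes the block and produces the expected gradients. A minimal sketch (assuming the reentrant implementation, i.e. use_reentrant=True on newer PyTorch versions):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

block = nn.Linear(100, 100)
x = torch.randn(2, 100, requires_grad=True)

# The reentrant implementation runs the block under no_grad in forward
# and re-runs it with grad enabled during backward.
out = checkpoint(block, x, use_reentrant=True)
print(out.requires_grad)  # True, as observed above

out.sum().backward()
print(x.grad is not None)             # True: gradients flow to the input
print(block.weight.grad is not None)  # True: and to the parameters
```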
Thanks for the code snippet, as now I see where the unexpected behavior is coming from.
If you add a debug print statement into the forward of CheckpointFunction here, you'll see that output indeed does not require gradients anymore. However, the returned tensor reports True in its requires_grad attribute again, as this seems to be a property of an autograd.Function:
```python
import torch

class MyCheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        print(x.requires_grad)
        with torch.no_grad():
            y = x + 1
        print(y.requires_grad)
        return y

f = MyCheckpointFunction.apply
x = torch.randn(1, requires_grad=True)
out = f(x)
print(out.requires_grad)
```

Output:

```
True
False
True
```
Yes! It is exactly the same misbehaviour I noticed in my previous experiments, but I couldn't find where the attribute is modified. Do you have any idea about this?
BTW, I found that the requires_grad attribute of the output tensor also depends on the inputs: once I set requires_grad=False for the inputs, the output tensors also have requires_grad=False.
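That observation can be reproduced without checkpoint at all. A minimal sketch (with a hypothetical Plus1 function): apply() marks its outputs as requiring grad exactly when at least one tensor input does:

```python
import torch

class Plus1(torch.autograd.Function):
    # Minimal hypothetical Function: the forward is invisible to autograd,
    # and the backward of x + 1 is the identity.
    @staticmethod
    def forward(ctx, x):
        return x + 1

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out

a = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=False)
print(Plus1.apply(a).requires_grad)  # True: follows the input
print(Plus1.apply(b).requires_grad)  # False: no input requires grad
```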
I don't think this is a misbehavior, but exactly how custom autograd.Functions are designed: they have to act this way on input tensors to avoid breaking the computation graph.
E.g., a custom autograd.Function can be used to implement any operation from a 3rd-party library (such as numpy) which Autograd cannot track. If the output tensor did not require gradients anymore, your layer would never work inside a proper computation graph.
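As a sketch of that point (with a hypothetical NumpySin function, not part of torch): the forward computes through numpy, which autograd cannot track, so the Function must keep its output inside the graph and supply the derivative itself:

```python
import numpy as np
import torch

class NumpySin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        # numpy op: invisible to autograd, so no graph is recorded here
        return torch.from_numpy(np.sin(x.detach().numpy()))

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * torch.cos(x)  # d/dx sin(x) = cos(x)

x = torch.randn(3, requires_grad=True)
y = NumpySin.apply(x)
print(y.requires_grad)  # True: the layer stays inside the graph
y.sum().backward()
print(torch.allclose(x.grad, torch.cos(x).detach()))  # True
```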
Thanks for your explanation. This makes things much clearer for me.