torch.utils.checkpoint does not support list inputs

I tried to use torch.utils.checkpoint in my code, but it does not seem to support list inputs.

Below is the standard usage of torch.utils.checkpoint:

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.fc1 = nn.Linear(784, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 512)
        self.fc4 = nn.Linear(512, 10)

    def forward(self, x):
        # Apply the first 2 layers
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        # Apply the checkpointed layer (fc3); its activations are recomputed during backward
        x = cp.checkpoint(self._checkpointed_forward, x)

        # Apply the last layer
        x = self.fc4(x)
        return x

    def _checkpointed_forward(self, x):
        x = nn.functional.relu(self.fc3(x))
        return x

# Create a random input tensor
x = torch.randn(2, 784)

# Create the neural network
model = MyNet()

# Compute the output
output = model(x)

# Compute the gradients
output.sum().backward()
# import pdb; pdb.set_trace()

# Print the gradients of the first layer
print(model.fc1.weight.grad)

If I modify the code as follows (the main change is that the input to _checkpointed_forward is wrapped in a list):

class MyNet(nn.Module):
    def forward(self, x):
        ...
        x = cp.checkpoint(self._checkpointed_forward, [x])
        x = self.fc4(x)
        return x

    def _checkpointed_forward(self, x):
        x = nn.functional.relu(self.fc3(x[0]))
        return x
...
output.sum().backward()
print(model.fc1.weight.grad)

In this case, model.fc1.weight.grad is None.
I also put a pdb.set_trace() breakpoint at the start of CheckpointFunction.backward:

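class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def backward(ctx, *args):
        import pdb; pdb.set_trace()  # breakpoint added at the top of the existing backward
        ...  # rest of the original backward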

The breakpoint is never hit, so CheckpointFunction.backward is not even executed.

Why does the code behave this way?
(torch.__version__ = 1.10.0)

@albanD @ptrblck @richard Could anyone give some advice?

List inputs are supported if you pass use_reentrant=False to torch.utils.checkpoint.checkpoint (you may need a newer version of PyTorch than 1.10 for this argument).
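A minimal sketch of that fix applied to the list-input model from the question, assuming a PyTorch version that accepts use_reentrant. The use_reentrant constructor argument and the helper fc1_grad_after_backward are hypothetical additions for illustration; they run the same model with both settings so the difference in fc1's gradient is visible:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp


class MyNet(nn.Module):
    def __init__(self, use_reentrant):
        super().__init__()
        self.use_reentrant = use_reentrant  # hypothetical switch, for comparison only
        self.fc1 = nn.Linear(784, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 512)
        self.fc4 = nn.Linear(512, 10)

    def forward(self, x):
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        # The tensor is wrapped in a list, exactly as in the question
        x = cp.checkpoint(self._checkpointed_forward, [x],
                          use_reentrant=self.use_reentrant)
        return self.fc4(x)

    def _checkpointed_forward(self, xs):
        return nn.functional.relu(self.fc3(xs[0]))


def fc1_grad_after_backward(use_reentrant):
    # Hypothetical helper: build the model, run forward/backward, return fc1's weight grad
    torch.manual_seed(0)
    model = MyNet(use_reentrant)
    model(torch.randn(2, 784)).sum().backward()
    return model.fc1.weight.grad


print(fc1_grad_after_backward(use_reentrant=True) is None)   # True: the list hides x from autograd
print(fc1_grad_after_backward(use_reentrant=False) is None)  # False: fc1 receives a gradient
```

With use_reentrant=False the checkpointed segment is tracked by autograd like ordinary code, so the tensor inside the list stays connected to the graph.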
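The reason list inputs fail with use_reentrant=True is that the reentrant implementation does not check whether tensors nested inside lists require grad; it only looks at tensors passed directly as arguments. Because the tensor wrapped in [x] is invisible to autograd, the checkpoint output does not require grad, so backward never reaches CheckpointFunction.backward and no gradient reaches fc1, fc2 or fc3. With use_reentrant=True, at least one top-level tensor argument must have requires_grad=True; use_reentrant=False has no such constraint. See the torch.utils.checkpoint page of the PyTorch 2.1 documentation for more information.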
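To illustrate the top-level-tensor rule, here is a small sketch with a toy function (f_list, f_args, w and h are hypothetical names, not from the thread). With use_reentrant=True, a tensor passed inside a list leaves the checkpoint output detached from autograd, while the same tensor passed as a positional argument keeps it connected. Unpacking the list into positional arguments (cp.checkpoint(fn, *tensors)) is therefore a possible workaround on 1.10, which as far as I know predates the use_reentrant argument and always uses the reentrant behavior; on 1.10 you would drop use_reentrant=True from the calls.

```python
import torch
import torch.utils.checkpoint as cp


# Toy segment that receives its tensors inside a list (as in the question)
def f_list(xs):
    return torch.relu(xs[0] * 2)


# Same computation, but receiving the tensors as positional arguments
def f_args(*xs):
    return torch.relu(xs[0] * 2)


w = torch.randn(3, requires_grad=True)
h = w * 1.0  # intermediate tensor that requires grad, like the fc2 output

# Reentrant checkpoint only inspects top-level arguments, so the tensor in the list
# is ignored (PyTorch may warn that none of the inputs require grad)
out_list = cp.checkpoint(f_list, [h], use_reentrant=True)

# Unpacking the list makes h a top-level argument, so autograd tracks it
out_args = cp.checkpoint(f_args, *[h], use_reentrant=True)

print(out_list.requires_grad)  # False
print(out_args.requires_grad)  # True

out_args.sum().backward()
print(w.grad is not None)  # True
```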
