Checkpoint doesn’t support list inputs

I tried to use torch.utils.checkpoint in my code, but found that it does not seem to support list inputs.

Below is the standard usage of torch.utils.checkpoint:

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.fc1 = nn.Linear(784, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 512)
        self.fc4 = nn.Linear(512, 10)

    def forward(self, x):
        # Apply the first two layers (fc1 and fc2)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        # Apply the checkpointed layer (fc3)
        x = cp.checkpoint(self._checkpointed_forward, x)

        # Apply the final layer
        x = self.fc4(x)
        return x

    def _checkpointed_forward(self, x):
        x = nn.functional.relu(self.fc3(x))
        return x

# Create a random input tensor
x = torch.randn(2, 784)

# Create the neural network
model = MyNet()

# Compute the output
output = model(x)

# Compute the gradients
output.sum().backward()

# Print the gradients of the first layer
print(model.fc1.weight.grad)
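
This works as expected: model.fc1.weight.grad is populated. As a quick sanity check, the gradients also match an equivalent forward pass without checkpointing (model_ref below is just a reference copy I made for the comparison):

# Rebuild the same computation without checkpointing, using copied weights
model_ref = MyNet()
model_ref.load_state_dict(model.state_dict())

h = nn.functional.relu(model_ref.fc1(x))
h = nn.functional.relu(model_ref.fc2(h))
h = nn.functional.relu(model_ref.fc3(h))
out_ref = model_ref.fc4(h)
out_ref.sum().backward()

# Both versions compute the same gradients for fc1
print(torch.allclose(model.fc1.weight.grad, model_ref.fc1.weight.grad))  # True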

If I modify the above code as follows (the main change is wrapping x in a list before passing it to cp.checkpoint):

class MyNet(nn.Module):
    def forward(self, x):
        ...
        x = cp.checkpoint(self._checkpointed_forward, [x])
        x = self.fc4(x)
        return x

    def _checkpointed_forward(self, x):
        x = nn.functional.relu(self.fc3(x[0]))
        return x
...
output.sum().backward()
print(model.fc1.weight.grad)

I observe that model.fc1.weight.grad is None in this case.
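
In fact, a quick check I ran on the modified model suggests the graph is already cut right at the checkpoint call:

h = nn.functional.relu(model.fc2(nn.functional.relu(model.fc1(x))))
out = cp.checkpoint(model._checkpointed_forward, [h])
print(out.requires_grad)  # False
print(out.grad_fn)        # None: nothing connects fc3's output back to fc1/fc2
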
Additionally, I put a pdb.set_trace() inside CheckpointFunction.backward in torch/utils/checkpoint.py:

class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def backward(ctx, *args):
        import pdb; pdb.set_trace()
        ...

I notice that CheckpointFunction.backward is not even executed.

I am curious why the code behaves this way.
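
To narrow it down, I also tried a minimal standalone autograd.Function (a toy Probe class I wrote just for this test, not part of torch) that mimics what I understand CheckpointFunction.forward to do, i.e. run the wrapped computation under torch.no_grad():

import torch

class Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, arg):
        # Mimic CheckpointFunction: run the real computation without grad
        with torch.no_grad():
            x = arg[0] if isinstance(arg, list) else arg
            return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        print("backward called")
        return grad_out * 2

t = torch.randn(3, requires_grad=True)

print(Probe.apply(t).requires_grad)    # True: autograd recorded the Probe node
Probe.apply(t).sum().backward()        # prints "backward called"
print(Probe.apply([t]).requires_grad)  # False: the tensor is hidden inside the list

So it seems autograd only inspects the direct tensor arguments of Function.apply; a tensor hidden inside a list is invisible to it, the output comes back detached, and backward is never reached, which would explain the None gradient.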

torch.utils.checkpoint is used for activation checkpointing, trading compute for memory. I'm unsure if you are trying to save a list object, but if so, use torch.save.

Hi, I updated the code example. Could you help check it? @ptrblck
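
In case it's useful: as far as I can tell, the reentrant checkpoint only tracks direct tensor arguments, so passing the tensor itself (or unpacking several tensors as positional arguments, e.g. cp.checkpoint(fn, x1, x2)) keeps the graph connected. A minimal sketch (the names fc, run, and inp are just placeholders):

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

fc = nn.Linear(1024, 512)

def run(t):
    return nn.functional.relu(fc(t))

inp = torch.randn(2, 1024, requires_grad=True)

out = cp.checkpoint(run, inp)  # pass the tensor directly, not [inp]
out.sum().backward()
print(fc.weight.grad is not None)  # True: the graph stays connected

If your PyTorch version has the use_reentrant argument, a function that takes a list (like the _checkpointed_forward above) should also work with cp.checkpoint(fn, [inp], use_reentrant=False), since the non-reentrant implementation does not route the arguments through an autograd.Function.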