Can't use checkpoint_sequential with resnet50

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential
from torchvision.models import resnet50

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = resnet50().to(device)
modules = list(model.children())  # top-level child modules of resnet50

inputs = torch.randn(4, 3, 224, 224, device=device)
inputs.requires_grad = True  # lets gradients flow through the (reentrant) checkpointed segments
output = checkpoint_sequential(modules, len(modules), inputs)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # assumed optimizer; any torch.optim optimizer works
target = torch.randint(0, 5, (4,), device=device)  # one label per sample in the batch
loss = criterion(output, target)
loss.backward()
optimizer.step()
```

Running this, I get the following error:

```
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py in forward(self, input)
    112 
    113     def forward(self, input: Tensor) -> Tensor:
--> 114         return F.linear(input, self.weight, self.bias)
    115 
    116     def extra_repr(self) -> str:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (8192x1 and 2048x1000)
```

How can I use checkpoint_sequential with resnet50? Any suggestions? :sob: @ptrblck_de

`checkpoint_sequential` expects a list of modules/functions that are applied strictly in order, each one receiving the previous one's output as its input.
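A minimal sketch of that contract, using a made-up three-layer model rather than resnet50: running the list through `checkpoint_sequential` should give the same result as calling each module in order on the previous output, since checkpointing only changes what is stored for the backward pass. It passes `use_reentrant=False`, which assumes a recent PyTorch release that accepts that argument.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)

# Hypothetical toy model: each step takes the previous step's output as input.
layers = [nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4)]
x = torch.randn(2, 8)

# Plain forward: apply each module in list order.
expected = x
for layer in layers:
    expected = layer(expected)

# Checkpointed forward over the same list, split into 2 segments.
out = checkpoint_sequential(layers, 2, x, use_reentrant=False)
print(torch.allclose(out, expected))  # True

# Gradients still reach the parameters of the checkpointed segment.
out.sum().backward()
print(layers[0].weight.grad is not None)  # True
```

If any step of the real forward is missing from the list, the chain breaks exactly as in the question.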

Please check the forward function of torchvision's ResNet, which resnet50 uses:

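The `x = torch.flatten(x, 1)` step between `avgpool` and `fc` in `forward` is a function call, not a child module, so it is not included in your `modules` list built from `model.children()`. Without it, `fc` receives the 4D `avgpool` output of shape `[4, 2048, 1, 1]` and applies its weight to the last dimension (size 1). That is where the shapes in the error come from: `8192x1` is the input viewed as (4 × 2048 × 1) rows by 1 column, and `2048x1000` is the transposed `fc` weight.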
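To see the mismatch in isolation, here is a small reproduction that uses a standalone `nn.Linear(2048, 1000)` as a stand-in for resnet50's `fc` and a random tensor in place of the real `avgpool` output (both are assumptions for illustration):

```python
import torch
from torch import nn

fc = nn.Linear(2048, 1000)           # same in/out features as resnet50's fc
pooled = torch.randn(4, 2048, 1, 1)  # shape of avgpool's output for a batch of 4

# Without flattening, Linear works on the last dim (size 1), so the
# input behaves like a (4 * 2048 * 1) x 1 matrix and cannot be multiplied.
failed = False
try:
    fc(pooled)
except RuntimeError as err:
    failed = True
    print(err)

# With the flatten step from forward, the shapes line up.
logits = fc(torch.flatten(pooled, 1))
print(logits.shape)  # torch.Size([4, 1000])
```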

You can verify this by printing `len(modules)`: it is 10 (`conv1`, `bn1`, `relu`, `maxpool`, `layer1` to `layer4`, `avgpool`, `fc`), while `forward` performs 11 operations. The missing one is `torch.flatten(x, 1)`.