import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential
from torchvision.models import resnet50
model = resnet50()
modules = [module for k, module in enumerate(model.children())]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
inputs = torch.randn(4, 3, 224, 224).to(device)
inputs.requires_grad = True
output = checkpoint_sequential(modules, len(modules), inputs)
criterion = nn.CrossEntropyLoss()
target = torch.empty(1, dtype=torch.long).random_(5).to(device)
loss = criterion(output, target)
loss.backward()
optimizer.step()
I got
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py in forward(self, input)
112
113 def forward(self, input: Tensor) → Tensor:
→ 114 return F.linear(input, self.weight, self.bias)
115
116 def extra_repr(self) → str:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (8192x1 and 2048x1000)