Very cool - this workaround does the trick perfectly! Thanks a lot.
Here’s a standalone implementation of it for reference:
import torch
from torch import nn
from torchvision import models
from torch.utils.checkpoint import checkpoint


class ModuleWrapperIgnores2ndArg(nn.Module):
    """Wraps a module so that checkpoint() receives an extra input that requires grad."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x, dummy_arg=None):
        # The dummy argument is never used; it only exists so that at least
        # one input to the checkpointed function has requires_grad=True.
        assert dummy_arg is not None
        x = self.module(x)
        return x


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # conv1, bn1, relu, maxpool and layer1 of ResNet-18
        self.features = nn.Sequential(*list(models.resnet18(pretrained=False).children())[:5])
        self.fc1 = nn.Linear(200704, 2)  # layer1 output: 64 x 56 x 56 = 200704
        self.dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True)
        self.module_wrapper = ModuleWrapperIgnores2ndArg(self.features)

    def forward(self, x):
        # x = checkpoint(self.features, x)  # fails: no input requires grad
        x = checkpoint(self.module_wrapper, x, self.dummy_tensor)
        print(x.shape)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x


model = MyModel().cuda()
x = torch.randn(1, 3, 224, 224).cuda()
output = model(x)
output.mean().backward()
print(model.features[0].weight.grad)  # gradients now reach the checkpointed weights
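
As a side note, and only if your PyTorch version supports it: passing use_reentrant=False selects the non-reentrant checkpoint implementation, which does not have the "at least one input must require grad" limitation, so the dummy-tensor wrapper becomes unnecessary. A minimal sketch under that assumption:

import torch
from torch import nn
from torchvision import models
from torch.utils.checkpoint import checkpoint

# Sketch only: assumes checkpoint() accepts use_reentrant (newer PyTorch releases).
features = nn.Sequential(*list(models.resnet18(pretrained=False).children())[:5]).cuda()
x = torch.randn(1, 3, 224, 224).cuda()  # input itself does not require grad

out = checkpoint(features, x, use_reentrant=False)
out.mean().backward()
print(features[0].weight.grad is not None)  # True: no dummy tensor needed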