Checkpoint with no grad-requiring inputs PROBLEM

Very cool - this workaround does the trick perfectly!

Thanks a lot :smiley:

The reason it works: checkpoint only builds a backward graph if at least one of its inputs has requires_grad=True, so passing a dummy tensor that requires grad alongside the real input lets gradients flow to the parameters inside the checkpointed module without changing the computation.

Here's a standalone implementation of it for reference:

import torch
from torch import nn
from torchvision import models
from torch.utils.checkpoint import checkpoint

class ModuleWrapperIgnores2ndArg(nn.Module):
    """Wraps a module so checkpoint can be given an extra, grad-requiring dummy input."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x, dummy_arg=None):
        # The dummy argument only exists to give checkpoint an input with
        # requires_grad=True; it is not used in the computation.
        assert dummy_arg is not None
        x = self.module(x)
        return x

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # First few layers of ResNet-18 (conv1, bn1, relu, maxpool, layer1)
        self.features = nn.Sequential(*list(models.resnet18(pretrained=False).children())[:5])
        self.fc1 = nn.Linear(200704, 2)  # 64 * 56 * 56 features for a 224x224 input
        # Dummy tensor whose only purpose is to require grad, so that checkpoint
        # sees at least one grad-requiring input
        self.dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True)
        self.module_wrapper = ModuleWrapperIgnores2ndArg(self.features)

    def forward(self, x):
        # Calling checkpoint(self.features, x) directly warns that none of the
        # inputs require grad, and the features' parameters get no gradients.
        #x = checkpoint(self.features, x)
        x = checkpoint(self.module_wrapper, x, self.dummy_tensor)
        print(x.shape)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

model = MyModel().cuda()
x = torch.randn(1, 3, 224, 224).cuda()
output = model(x)
output.mean().backward()
print(model.features[0].weight.grad)
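
For what it's worth, newer PyTorch releases (roughly 1.11 and later, if I remember correctly) let you select the non-reentrant checkpoint implementation with use_reentrant=False, which handles inputs that don't require grad, so the wrapper and dummy tensor shouldn't be needed there. A minimal sketch, assuming such a version is installed:

import torch
from torch import nn
from torchvision import models
from torch.utils.checkpoint import checkpoint

class MyModelNonReentrant(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(*list(models.resnet18(pretrained=False).children())[:5])
        self.fc1 = nn.Linear(200704, 2)

    def forward(self, x):
        # The non-reentrant implementation tracks gradients of the module's
        # parameters even when x itself does not require grad
        x = checkpoint(self.features, x, use_reentrant=False)
        x = x.view(x.size(0), -1)
        return self.fc1(x)

model = MyModelNonReentrant().cuda()
out = model(torch.randn(1, 3, 224, 224).cuda())
out.mean().backward()
print(model.features[0].weight.grad is not None)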