How to use torch.utils.checkpoint

Just as the title says, how should I modify my forward pass to use this function?

My current forward pass is:

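    import torch
    import torch.nn.functional as F

    def forward(self, x):
        out = F.relu(self.pool1(self.conv1(x)))
        out = F.relu(self.pool2(self.conv2(out)))
        out = torch.flatten(out, 1)  # flatten conv features (N, C, H, W) -> (N, C*H*W) for the linear layers
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out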

Many thanks!

I like @Priya_Goyal’s tutorial on checkpointing.
Note that it hasn't been updated in a while and uses an old PyTorch version, but the general workflow should be the same.
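As a rough sketch of that workflow applied to the forward pass above: wrap each conv block in a function and pass it to `torch.utils.checkpoint.checkpoint`, so its activations are recomputed during backward instead of being stored. This assumes PyTorch 1.11 or newer (for the `use_reentrant` argument). The layer sizes (1×28×28 inputs, 10 classes) and the helper names `_block1`/`_block2` are hypothetical, since the original `__init__` isn't shown.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


class Net(nn.Module):
    # Hypothetical sizes: assumes 1x28x28 inputs (e.g. MNIST-sized images).
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)

    # Hypothetical helpers: each block is a plain function of a tensor, so
    # checkpoint() can re-run it during backward instead of keeping its
    # intermediate activations in memory.
    def _block1(self, x):
        return F.relu(self.pool1(self.conv1(x)))

    def _block2(self, x):
        return F.relu(self.pool2(self.conv2(x)))

    def forward(self, x):
        if self.training:
            # use_reentrant=False lets gradients reach conv1's parameters even
            # though the input x itself does not require grad.
            out = checkpoint(self._block1, x, use_reentrant=False)
            out = checkpoint(self._block2, out, use_reentrant=False)
        else:
            # Outside training there is usually no backward pass, so call the blocks directly.
            out = self._block1(x)
            out = self._block2(out)
        out = torch.flatten(out, 1)  # (N, 32, 7, 7) -> (N, 1568)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        return self.fc3(out)


if __name__ == "__main__":
    model = Net().train()
    x = torch.randn(4, 1, 28, 28)
    loss = model(x).sum()
    loss.backward()  # _block1 and _block2 are recomputed here
    print(model.conv1.weight.grad.shape)  # torch.Size([16, 1, 3, 3])
```

Only the conv blocks are checkpointed here because they produce the largest activations; the small linear layers are left alone. The trade-off is extra compute, since the checkpointed blocks run a second forward pass during backward, and gradients should match the non-checkpointed version. If the layers were wrapped in an `nn.Sequential`, `torch.utils.checkpoint.checkpoint_sequential` would be an alternative.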
