I have defined a simple model that trains on the CIFAR10 dataset included in torchvision. This model converges to a loss of ~0.72 after 10 epochs of training. Here is the model definition:
class CIFAR10Model(nn.Module):
def __init__(self):
super().__init__()
self.cnn_block_1 = nn.Sequential(*[
nn.Conv2d(3, 32, 3, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.Dropout(0.25)
])
self.cnn_block_2 = nn.Sequential(*[
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.Dropout(0.25)
])
self.flatten = lambda inp: torch.flatten(inp, 1)
self.head = nn.Sequential(*[
nn.Linear(64 * 8 * 8, 512),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(512, 10)
])
def forward(self, X):
X = self.cnn_block_1(X)
X = self.cnn_block_2(X)
X = self.flatten(X)
X = self.head(X)
return X
I then defined and trained a new version of this model, which has been updated to use torch.utils.checkpoint.checkpoint. This model has the following definition (updates were made to exclude nn.Dropout from checkpointing, since Dropout is incompatible with checkpointing):
class CIFAR10Model(nn.Module):
def __init__(self):
super().__init__()
self.cnn_block_1 = nn.Sequential(*[
nn.Conv2d(3, 32, 3, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
])
self.dropout_1 = nn.Dropout(0.25)
self.cnn_block_2 = nn.Sequential(*[
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
])
self.dropout_2 = nn.Dropout(0.25)
self.flatten = lambda inp: torch.flatten(inp, 1)
self.linearize = nn.Sequential(*[
nn.Linear(64 * 8 * 8, 512),
nn.ReLU()
])
self.dropout_3 = nn.Dropout(0.5)
self.out = nn.Linear(512, 10)
def forward(self, X):
X = torch.utils.checkpoint.checkpoint(self.cnn_block_1, X)
X = self.dropout_1(X)
X = torch.utils.checkpoint.checkpoint(self.cnn_block_2, X)
X = self.dropout_2(X)
X = self.flatten(X)
X = self.linearize(X)
X = self.dropout_3(X)
X = self.out(X)
return X
However, this model achieves radically worse performance than its uncheckpointed peer: just ~1.59 loss after 10 epochs!
I don’t understand why this is the case. This is my first time using this API; perhaps someone here who understands it better than I can point out where my error is?
For reference, here is the full code and logs: uncheckpointed version, checkpointed version.
