I have defined a simple model that trains on the CIFAR10 dataset included in torchvision
. This model converges to a loss of ~0.72 after 10 epochs of training. Here is the model definition:
class CIFAR10Model(nn.Module):
def __init__(self):
super().__init__()
self.cnn_block_1 = nn.Sequential(*[
nn.Conv2d(3, 32, 3, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.Dropout(0.25)
])
self.cnn_block_2 = nn.Sequential(*[
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
nn.Dropout(0.25)
])
self.flatten = lambda inp: torch.flatten(inp, 1)
self.head = nn.Sequential(*[
nn.Linear(64 * 8 * 8, 512),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(512, 10)
])
def forward(self, X):
X = self.cnn_block_1(X)
X = self.cnn_block_2(X)
X = self.flatten(X)
X = self.head(X)
return X
I then defined and trained a new version of this model, which has been updated to use torch.utils.checkpoint.checkpoint
. This model has the following definition (updates were made to exclude nn.Dropout
from checkpointing, since Dropout is incompatible with checkpointing):
class CIFAR10Model(nn.Module):
def __init__(self):
super().__init__()
self.cnn_block_1 = nn.Sequential(*[
nn.Conv2d(3, 32, 3, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
])
self.dropout_1 = nn.Dropout(0.25)
self.cnn_block_2 = nn.Sequential(*[
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 64, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2)
])
self.dropout_2 = nn.Dropout(0.25)
self.flatten = lambda inp: torch.flatten(inp, 1)
self.linearize = nn.Sequential(*[
nn.Linear(64 * 8 * 8, 512),
nn.ReLU()
])
self.dropout_3 = nn.Dropout(0.5)
self.out = nn.Linear(512, 10)
def forward(self, X):
X = torch.utils.checkpoint.checkpoint(self.cnn_block_1, X)
X = self.dropout_1(X)
X = torch.utils.checkpoint.checkpoint(self.cnn_block_2, X)
X = self.dropout_2(X)
X = self.flatten(X)
X = self.linearize(X)
X = self.dropout_3(X)
X = self.out(X)
return X
However, this model achieves radically worse performance than its uncheckpointed peer: just ~1.59 loss after 10 epochs!
I don’t understand why this is the case. This is my first time using this API; perhaps someone here who understands it better than I can point out where my error is?
For reference, here is the full code and logs: uncheckpointed version, checkpointed version.