Hi all,
I’m trying to train a model on my GPU (RTX 2080 super) using Gradient Checkpointing in order to significantly reduce the usage of VRAM. I’m using torch.utils.checkpoint.checkpoint
.
The model in which I want to apply it is a simple CNN with a flatten layer at the end. Although I think I applied it right I’m not having any memory usage reduction. The memory usage with Gradient Checkpointing is the same as without it, however I do see a increase in the time per epoch (something expected given the nature of Gradient Checkpointing).
I’ve been searching through the forums but I didn’t see anything related with this.
The model definition is the following:
class my_model(torch.nn.Module):
def __init__(self):
super(my_model, self).__init__()
# Input shape [19, 22, 20] -> Output shape [3258]
self.conv1 = torch.nn.Conv2d(in_channels = 20,
out_channels = 50,
kernel_size = 3,
padding = 1)
self.conv2 = torch.nn.Conv2d(in_channels = 50,
out_channels = 25,
kernel_size = 3,
padding = 1)
self.conv3 = torch.nn.Conv2d(in_channels = 25,
out_channels = 10,
kernel_size = 3,
padding = 1)
self.fc1 = torch.nn.Linear(in_features = 19 * 22 * 10,
out_features = 3258)
def flat(self, x):
return torch.flatten(x, 1)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = checkpoint(self.conv2, x)
x = checkpoint(F.relu, x)
x = checkpoint(self.conv3, x)
x = checkpoint(F.relu, x)
x = checkpoint(self.flat, x)
x = checkpoint(self.fc1, x)
return x
I understand that when you checkpoint some operation the activations of that operation are not saved in memory (it’s recomputed when needed by the backward operation).
For show purposes I define some random train and validation data:
class MyDataset(Dataset):
def __init__(self, data, target, transform = None):
self.data = torch.from_numpy(data).float()
self.target = torch.from_numpy(target).float()
self.transform = transform
def __getitem__(self, index):
x = self.data[index]
y = self.target[index]
if self.transform:
x = self.transform(x)
return x, y
def __len__(self):
return len(self.data)
# Random data
x = np.random.randn(8766, 20, 19, 22) # 19 x 22 x 20 array
y = np.random.randn(8766, 3258) # 3258 vector
# Generate train and validation indices
validation_split = 0.1
idxs = list(range(x.shape[0]))
np.random.shuffle(idxs)
split = int(np.floor(validation_split * len(idxs)))
train_idxs, val_idxs = idxs[split:], idxs[:split]
Train and Validation DataLoaders:
# Train dataloader
dataset_train = MyDataset(data = x[train_idxs, :, :, :],
target = y[train_idxs, :])
loader_train = DataLoader(dataset_train,
batch_size = 100,
shuffle = True,
num_workers = 1)
# Validation dataloader
dataset_val = MyDataset(data = x[val_idxs, :, :, :],
target = y[val_idxs, :])
loader_val = DataLoader(dataset_val,
batch_size = 100,
shuffle = True,
num_workers = 1)
And I finally I train the model:
# Instanciate the model
model = my_model.cuda()
# Initialize lowest loss value registered (validation)
lowest_loss_val = 10000
# Initialize the number of epochs without decreasement in loss (validation)
epochs_loss_val = 1
# List for the per epoch losses in train
avg_train_loss_list = []
# List for the per epoch losses in validation
avg_val_loss_list = []
# Model hyperparameters
epochs = 100000
learning_rate = 0.0001
patience = 30 # Early stopping patience
# Loss function and optimizer
loss_function = torch.nn.MSELoss()
optimizer = optim.Adam(model.parameters(),
lr = learning_rate)
# Iterate over epochs
for epoch in range(epochs):
# Record start time of epoch
x1 = time.time()
# Training step
model.train()
loss_train_list = []
for batch_idx, (data, target) in enumerate(loader_train):
# Clean gradients value
optimizer.zero_grad()
# Forward pass
output = model(data.cuda())
# Loss calculation
loss = loss_function(output, target.cuda())
# Append batch loss
loss_train_list.append(loss.item())
# Gradients calculation
loss.backward()
# Optimizer update
optimizer.step()
# Calculate total time spent on epoch
x2 = time.time() - x1
# Training loss averaged over batches (loss per epoch)
avg_train_loss_list.append(np.average(loss_train_list))
# Validation step (earlystopping purposes)
model.eval()
loss_val_list = []
for batch_idx, (data, target) in enumerate(loader_val):
# Forward pass
output = model(data.cuda())
# Loss calculation
loss = loss_function(output, target.cuda())
# Append batch loss
loss_val_list.append(loss.item())
# Validation loss averaged over batches
avg_val_loss_list.append(np.average(loss_val_list))
# Log info
print('Epoch {} | Training Loss {} | Validation Loss {} | Elapsed seconds {}'.format(epoch + 1,
round(avg_train_loss_list[-1], 5),
round(avg_val_loss_list[-1], 5),
round(x2, 5)))
# Earlystopping checking
# If latest loss is lowest than the lower one
if avg_val_loss_list[-1] < lowest_loss_val:
lowest_loss_val = avg_val_loss_list[-1]
epochs_loss_val = 1
# Save model
torch.save(model.state_dict(), 'model.pt')
# If we've been without losses decreasement for patience epochs
elif epochs_loss_val > patience:
print('{} epochs without loss reduction. Finishing training...'.format(epochs_loss_val - 1))
print('Final loss: {}'.format(round(lowest_loss_val, 5)))
break # Break training loop
epochs_loss_val += 1
I’ve been trying to find without success the reason why this is not behaving as expected. Is my implementation wrong? Can’t my model’s architecture benefit from Gradient Checkpointing? I would appreciate any help.
I’m running the code on Pytorch 1.31
Thank you in advance!!