My code runs into a Cuda OOM error when I am accumulating loss inside a loop. I understand that this due to the computational graph growing with each iteration. Given, that the inputs are images, this would be problematic. I have also added ‘del’ statements to manually free memory, but that still does not help and I run into the CUDA OOM issue within a few iterations inside the loop. Is there any way to get this running or do I need to alter my approach of accumulating loss inside a loop? I have attached the code below
Any of the functions not provided are simple functions for formatting or calculating the Cross entropy, which are straightforward and do not seem to be the cause of the CUDA OOM error
def get_dist_lp(self, batch, format_batch_func):
batch = format_batch_func(batch)
obs = torch.cat((batch["obs_1"], batch["obs_2"]), dim=0)
action = torch.cat((batch["action_1"], batch["action_2"]), dim=0)
del batch
obs = self.network.encoder(obs)
dist = self.network.actor(obs)
del obs
if isinstance(dist, torch.distributions.Distribution):
lp = dist.log_prob(action)
else:
assert dist.shape == action.shape
# For independent gaussian with unit var, logprob reduces to MSE.
lp = -torch.square(dist - action).sum(dim=-1)
return lp
# For now only works on th e train data, where we get a tuple of batch and perturbed_labels
def get_avg_likelihood(self, train_dataloader, format_batch_func):
tot_cpl_ll = 0.
tot_correct_preds = 0
data_size = 0.
mini_batch_cnt = 0
for i, (batch, labels_perturbed) in enumerate(train_dataloader):
labels_perturbed = labels_perturbed.to(self.device)
lp = self.get_dist_lp(batch, format_batch_func)
# Compute the advantages.
adv = self.alpha * lp
segment_adv = adv.sum(dim=-1)
adv1, adv2 = torch.chunk(segment_adv, 2, dim=0)
# cpl_ll is of shape (num_dataset_perturbations,)
cpl_ll, correct_preds = biased_bce_with_logits_perturbed(adv1, adv2, labels_perturbed, bias=self.contrastive_bias)
tot_cpl_ll = tot_cpl_ll + cpl_ll
tot_correct_preds += correct_preds
data_size += labels_perturbed.shape[0]
mini_batch_cnt += 1
del batch, labels_perturbed
avg_cpl_ll = tot_cpl_ll/data_size
accuracy_all_ds = tot_correct_preds/data_size
return avg_cpl_ll, accuracy_all_ds
def CVaR(self, ll_tensor, cvar_alpha=0.95):
sigma = ll_tensor.clone()
cvar = torch.max(sigma - torch.mean(F.relu(sigma[:,None] - ll_tensor), dim=1)/(1-cvar_alpha))
return cvar
def train_step_perturbed(self, train_dataloader, format_batch_func, cvar_alpha=0.95):
self.optim["actor"].zero_grad()
avg_ll_cpl, accuracy_all = self.get_avg_likelihood(train_dataloader, format_batch_func)
avg_accuracy = accuracy_all.mean()
cvar_ll = self.CVaR(avg_ll_cpl, cvar_alpha)
loss = -cvar_ll
loss.backward()
self.optim["actor"].step()
del avg_ll_cpl, accuracy_all
print(f"Loss: {loss.item():.4f} \n")
return dict(cpl_loss=loss.item(), bc_loss=0., accuracy=avg_accuracy.item())