Hi everyone — I trained my model using PyTorch Lightning. At the beginning, GPU memory usage is only 22%; however, after about 900 steps it climbs to around 68%.
Below is my training step. I have tried removing unnecessary tensors and clearing the CUDA cache, and I stepped through the loop once to check it — everything works fine. Where could the potential issue be that causes this memory-usage increase?
def training_step(self, batch, batch_idx):
    """Run one optimization step: forward pass, (optional) memory-bank
    losses, weighted loss aggregation, and logging.

    Parameters
    ----------
    batch : dict
        Expects keys 'image', 'label' and (when a memory bank is
        configured) 'path'.
    batch_idx : int
        Index of the batch within the epoch (unused here; required by
        the Lightning API).

    Returns
    -------
    torch.Tensor
        Scalar total loss for the backward pass.

    NOTE(review): the steady GPU-memory growth over the first ~900 steps
    is most likely the `self.memory` bank filling up to its capacity
    (embeddings are stored on the GPU until the bank is full), plus
    PyTorch's caching allocator retaining freed blocks — TODO confirm by
    checking `torch.cuda.memory_allocated()` vs `memory_reserved()`.
    """
    self.model.train()
    # Keep BatchNorm layers in eval mode (frozen running stats).
    self.model.apply(set_bn_eval)

    # NOTE(review): Lightning normally moves the batch to the right
    # device already; the explicit .to() calls are redundant but harmless.
    inputs = batch['image']
    labels = batch['label']
    inputs, labels = inputs.to(self.hparams.device), labels.to(self.hparams.device)

    _, di = self(inputs)
    # Pairwise similarity of the batch embeddings with themselves.
    scores = torch.mm(di, di.t())
    label_matrix = lib.create_label_matrix(labels)

    # Hoist the memory-bank activation test: it is loop-invariant and
    # was previously re-evaluated inside every criterion iteration.
    memory_active = hasattr(self, 'memory') and self.current_epoch >= self.memory.activate_after

    memory_scores = memory_label_matrix = None
    memory_embeddings = memory_labels = None
    if hasattr(self, 'memory'):
        # detach() so the bank never keeps the autograd graph alive.
        memory_embeddings, memory_labels = self.memory(di.detach(), labels, batch["path"])
        if memory_active:
            memory_scores = torch.mm(di, memory_embeddings.t())
            memory_label_matrix = lib.create_label_matrix(labels, memory_labels)

    # NOTE(review): `logs` is filled but never returned or read below;
    # kept for compatibility, but it looks like dead code.
    logs = {}
    losses = []
    for crit, weight in self.criterion:
        mem_loss = None
        if hasattr(crit, 'takes_embeddings'):
            loss = crit(di, labels.view(-1))
            if memory_active:
                mem_loss = crit(di, labels.view(-1), memory_embeddings, memory_labels.view(-1))
        else:
            loss = crit(scores, label_matrix)
            if memory_active:
                mem_loss = crit(memory_scores, memory_label_matrix)

        loss = loss.mean()
        if weight == 'adaptative':
            losses.append(loss)
        else:
            losses.append(weight * loss)
        # .item() detaches to a Python float so the log dict does not
        # retain the computation graph.
        logs[crit.__class__.__name__] = loss.item()

        if memory_active:
            mem_loss = mem_loss.mean()
            if weight == 'adaptative':
                losses.append(self.memory.weight * mem_loss)
            else:
                losses.append(weight * self.memory.weight * mem_loss)
            logs[f"self.memory_{crit.__class__.__name__}"] = mem_loss.item()

        self.log(crit.__class__.__name__, loss.item(), on_step = True, on_epoch = True, prog_bar = True, logger = True)

    total_loss = sum(losses)
    self.log("total_loss", total_loss.item(), on_step = True, on_epoch = True, prog_bar = True, logger = True)

    # Drop every graph-holding local before returning. The original
    # cleanup missed the memory-branch tensors (memory_scores,
    # memory_label_matrix, memory_embeddings, mem_loss), which could
    # keep graph references alive across steps.
    del losses, scores, label_matrix, di, labels, inputs, loss
    del memory_scores, memory_label_matrix, memory_embeddings, memory_labels, mem_loss
    # NOTE(review): empty_cache() every step only returns cached blocks
    # to the driver (it does not free live tensors) and slows training;
    # consider removing it once the growth is understood.
    torch.cuda.empty_cache()
    return total_loss