Exploding memory

I am facing an issue where my memory usage is exploding, and I can’t explain why. Here is what my code looks like:

modes = torch.randn(num_models, 1, 512, 30522).to(device)

for epoch in range(num_epochs):
    model.train()

    for batch_ix, batch in enumerate(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch) 
        
        distances = torch.cat([((outputs.logits - mode)**2).sum(axis=[1,2]) for mode in modes])

        chosen_model_ix = torch.argmin(distances)

        #update the new mode
        modes[chosen_model_ix] += outputs.logits
        modes[chosen_model_ix] /= 2

After playing with torch.no_grad(), it seems these lines are making the biggest difference to the GPU memory footprint:

#update the new mode
modes[chosen_model_ix] += outputs.logits
modes[chosen_model_ix] /= 2

Does this make sense? Why would these inplace operations blow up my memory?


It seems you are accumulating tensors that are still attached to a computation graph, which prevents PyTorch from freeing any intermediate tensors.
I don’t know if modes should stay differentiable, but if it doesn’t and is only used as a static target/reference, you should .detach() the logits before accumulating them.
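
Something along these lines should work (a minimal sketch, assuming modes is only used as a static reference):

#update the new mode without keeping each iteration's graph alive
modes[chosen_model_ix] += outputs.logits.detach()
modes[chosen_model_ix] /= 2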


Thanks, that seemed to curb the memory consumption. I’ll follow up with further questions if that doesn’t fully resolve my research problem.

Hey, just to reopen this thread: is there any reason that calculating accuracy should cause GPU memory usage to blow up?

This is the relevant part of my code. I’ve narrowed the bug down to the line that calculates accuracy; it seems to be the part of the code that’s blowing up the GPU memory.

model_cx_accuracy = Accuracy(task="multiclass", num_classes = len(models)).to(device)
mod_pred_class = model_ops.logits.argmax(dim = -1)
mod_mo_acc = gen_accuracy(mod_pred_class, batch['labels']) # Buggy line

Does this make sense?

It depends on what exactly gen_accuracy does. Did you check which tensors are allocated inside this function and how large they are?
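
For example, you could bracket the call with torch.cuda.memory_allocated() and check how much memory it keeps (a quick sketch):

before = torch.cuda.memory_allocated()
acc = gen_accuracy(mod_pred_class, batch['labels'])
print((torch.cuda.memory_allocated() - before) / 1024**2)  # MB retained by the call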

Whoops, I intended to paste this earlier but forgot:

gen_accuracy = Accuracy(task="multiclass", num_classes=30522, ignore_index=-100, average = 'micro', multidim_average = 'samplewise').to(device)

The tensors are 8 x 512 x 30522 each, but I don’t see why that space isn’t being allocated and then freed each time gen_accuracy is called.

The shape wouldn’t make sense for a multiclass classification, as the class dimension is expected to be in dim1. In any case, using your shapes the method seems to use ~500MB:

import torch
from torchmetrics import Accuracy

print(torch.cuda.memory_allocated()/1024**2)
# 0.0

device = "cuda"
gen_accuracy = Accuracy(task="multiclass", num_classes=30522, ignore_index=-100, average = 'micro', multidim_average = 'samplewise').to(device)

x = torch.randn(8, 30522, 512, device=device)
y = torch.randint(0, 30522, (8, 512), device=device)
print(torch.cuda.memory_allocated()/1024**2)
# 476.9375

loss = gen_accuracy(x, y)
print(torch.cuda.memory_allocated()/1024**2)
# 484.39111328125

Yes, I think I argmax over the last dimension and then feed in two 8 x 512 tensors.

But this doesn’t explain why the accuracy is exploding, right? It seems this accuracy call costs around 9 MB every time it’s made, and I think that’s what’s causing me to run out of memory.

for i in range(100):
  loss = gen_accuracy(preds, y)
  print(torch.cuda.memory_allocated()/1024**2)

Is there anything I can do to prevent my accuracy function from grabbing a new chunk of memory each time it’s called?
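
One thing that might be worth trying, since torchmetrics metrics are stateful and keep accumulating internal statistics across forward calls until reset() is called (just a guess that this is where the growth comes from; untested sketch):

for i in range(100):
  acc = gen_accuracy(preds, y)
  gen_accuracy.reset()  # drop the internally accumulated state
  print(torch.cuda.memory_allocated()/1024**2)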

It’s interesting that using the Accuracy() module blows up memory, but manually implementing it doesn’t.

import torch
from torchmetrics import Accuracy

print(torch.cuda.memory_allocated()/1024**2)
# 0.0

device = "cuda"


x = torch.randn(8, 512, 30522, device=device)
preds = x.argmax(dim = -1)
y = torch.randint(0, 30522, (8, 512), device=device)
print(torch.cuda.memory_allocated()/1024**2)
# 476.96875

for i in range(100):
  gen_accuracy = Accuracy(task="multiclass", num_classes=30522, ignore_index=-100, average = 'micro', multidim_average = 'samplewise').to(device)
  acc1 = torch.sum((preds == y).float(), dim = -1) #doesn't blow up memory
  acc2 = gen_accuracy(preds, y) #does blow up memory
  print(torch.cuda.memory_allocated()/1024**2)

Any thoughts on what I might be missing?

I don’t think you are missing anything, and I can also reproduce the increased memory usage.
However, I had to jump on another issue and haven’t debugged this one yet.
For now I would blame the Accuracy method and would use a manual implementation instead.
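
A manual samplewise accuracy that also respects the ignore_index=-100 masking could look like this (a sketch using your shapes above; not guaranteed to match torchmetrics’ exact semantics):

def samplewise_accuracy(preds, target, ignore_index=-100):
    # mask out the positions the metric should ignore
    valid = target != ignore_index
    correct = (preds == target) & valid
    # per-sample accuracy over the valid positions only
    return correct.sum(dim=-1).float() / valid.sum(dim=-1).clamp(min=1)

acc = samplewise_accuracy(preds, y)  # shape: [8]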


OK, let me also report this on torchmetrics’ GitHub and see if anyone has any thoughts.