Transfer learning with ResNet50: out-of-memory error when extracting 2048-element outputs


I am trying to extract 2048-element representations of fMNIST images from a pre-trained ResNet50. I need to store these vectors, presumably in a [60000, 2048] data structure (currently trying tensors), in order to use them in subsequent experiments. My workflow is as below:

  1. Load fMNIST from torchvision
  2. Create dataloaders
  3. Loop through batches (a for loop over a dataloader), extracting representations:
    - Inside the loop I call repeat_interleave to expand the greyscale channel to 3 channels (essentially copying the channel 3 times, not doing any colour conversion)
    - create a temp variable storing those representations (batch_size, 2048)
    - send the temporary output to the CPU (because apparently it won’t even complete 1 iteration on the GPU)
    - store the output in the final tensor.

After 4 iterations at batch size 64, a Tesla T4 on Colab with 15 GB crashes with an out-of-memory error. Running 1 iteration and checking memory with nvidia-smi shows 6.7 GB of GPU memory reserved. It might be naive, but I would have assumed that this is the memory needed at each iteration; apparently it isn’t?
How can I resolve this issue? What is the PyTorch way of generating such outputs from a pretrained model and storing them?

PS I have tried smaller batches and it works for a couple of extra iterations, but it still crashes. I am assuming that intermediate temporary variables are created and not discarded at each iteration. I have also tried to tackle this by calling:

del out

but to no avail. Any help would be greatly appreciated. Thank you.

Based on the description it seems that you are storing tensors in e.g. a list, which are still attached to the computation graph. Thus you wouldn’t only store the tensors, but the entire graph with all intermediate tensors in each iteration.
If you don’t need to call backward() on the stored tensor or any tensor created from it, you could detach() the tensor before storing it in the list (or any other container).
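A small illustration of the difference (a hypothetical nn.Linear stands in for the real model): the raw model output still requires grad and keeps the graph alive, while the detached copy does not.

```python
import torch

model = torch.nn.Linear(8, 4)         # hypothetical stand-in for the real model
out = model(torch.randn(2, 8))

print(out.requires_grad)              # True: still attached to the autograd graph
stored = out.detach()                 # break the link before storing it anywhere
print(stored.requires_grad)           # False: safe to keep in a list
```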


Thank you @ptrblck . I solved this by calling .detach().cpu().numpy() on the output of the model.
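Putting the fix together, a sketch of the whole collection loop (again a hypothetical Linear stands in for the ResNet50 backbone, and random batches for the dataloader):

```python
import numpy as np
import torch

model = torch.nn.Linear(8, 2048)      # hypothetical stand-in for the ResNet50 backbone
chunks = []
for _ in range(3):                    # stand-in for iterating the dataloader
    out = model(torch.randn(4, 8))
    # detach from the graph, move to host memory, convert to NumPy
    chunks.append(out.detach().cpu().numpy())

features = np.concatenate(chunks)     # (12, 2048) here; (60000, 2048) for fMNIST
print(features.shape)                 # (12, 2048)
```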