[resolved] Holy memory balloons, Batman!

Greetings all!

I am fairly new to PyTorch, but I have experience running very large models in Caffe, Keras, Theano, and others. My system has 3 GPUs, and at the moment I am running some tests on my Quadro GP100 (16.3 GB of video RAM). I wrote the model below to see how memory behaves in PyTorch. I probably should have started smaller. Ah well, go big or go home!

Anyway, the following code crashes after running out of memory on the Quadro (holy crap).

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from torch.autograd import Variable
 
import numpy as np
import math
import time
import random

import matplotlib.pyplot as plt

def DimCalcConv2D(Ishape, N, Kshape, Sshape, Pshape):
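  # Output shape of a Conv2d: (N filters, floor((H + 2*P_h - K_h)/S_h + 1), floor((W + 2*P_w - K_w)/S_w + 1))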
  Dimx = N

  Dimy = int(math.floor((Ishape[0] + 2*Pshape[0] - Kshape[0])/Sshape[0] + 1))
  Dimz = int(math.floor((Ishape[1] + 2*Pshape[1] - Kshape[1])/Sshape[1] + 1))

  return (Dimx, Dimy, Dimz)
  

class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator,self).__init__()
    
    self.ishape = (1,256,256)
    
    self.conv1 = nn.Conv2d(1, 64, 3, stride=1, padding=0)
    self.conv2 = nn.Conv2d(64, 64, 3, stride=1, padding=0)
    self.conv3 = nn.Conv2d(64, 64, 3, stride=1, padding=0)

    D1 = DimCalcConv2D((256, 256),64,(3,3),(1,1),(0,0))
    D2 = DimCalcConv2D(D1[1:],64, (3,3),(1,1),(0,0))
    D  = DimCalcConv2D(D2[1:],64, (3,3),(1,1),(0,0))

    self.ConvOutShape = D
    N = int(D[0]*D[1]*D[2])
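    # N = 64 * 250 * 250 = 4,000,000, so linear1 alone holds 4,000,000 * 256 ≈ 1.02e9 weights (~4.1 GB in float32)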

    self.linear1 = nn.Linear(N,256)
    self.linear2 = nn.Linear(256,256)
    self.linear3 = nn.Linear(256,2)

    self.train()
    

  def forward(self, Image):
    x = F.relu(self.conv1(Image))
    x = F.relu(self.conv2(x))
    x = F.relu(self.conv3(x))

    D = self.ConvOutShape
    N = int(D[0]*D[1]*D[2])
    x = x.view(-1,N)

    x = F.relu(self.linear1(x))
    x = F.relu(self.linear2(x))
    x = F.relu(self.linear3(x))

    return F.softmax(x)
    
if __name__ == "__main__":

  image = np.ones((1,1,256,256))
  
  Image = Variable(torch.from_numpy(image).float()).cuda(0)
  Truth = Variable(torch.from_numpy(np.array([1,0]).reshape(1,2)).float()).cuda(0)
  
  net = Discriminator()
  net.cuda(0) #4.4 GB video ram.  With float32 this is expected for this model
  net.train()
  
  optimizer = optim.Adam(net.parameters())
  BCE = torch.nn.BCELoss()
  
  out = net(Image)
  loss = BCE(out, Truth)
  loss.backward() #12.5 GB video ram.  Gradient buffer populated?  Why isn't this 8.8 GB? one grad per parameter...
  
  optimizer.step() # CRASH.  out of ram.  16.3 GB video ram

I have run much, MUCH larger models than this on my current system with other libraries (VGG 18, for example) and haven’t had many memory issues. At the moment I am assuming this is because other libraries do some kind of memory management and distribution across multiple GPUs behind the scenes, while PyTorch (I hope) leaves this to the user. This raises some questions.

  1. Did I do something dumb with this example 4.4 GB model that somehow causes it to balloon to 16+ GB? If it’s a simple fix, what is it? (Note: by “dumb” I mean something I missed in PyTorch. The model is purposefully huge so changes in memory are easier to distinguish.)

  2. Why does the memory increase during optimizer.step()? The backward call makes sense to me, since you need to populate the gradient buffers, but doesn’t step() just use those buffers to update the parameters? At most I would expect this to increase by another 4.4 GB for temporary variables for each parameter.

  3. If I didn’t do something dumb, is there a good “rule of thumb” for predicting how much memory a model will occupy on a card during training?

  4. Is there a quick/easy way to tell PyTorch “hey, I’ve got these two other GPUs. You should use them intelligently!”? How do I distribute my model across my other GPUs?

Thanks!
Gus


I’m sorry for the late reply. This last week has been hectic.
I’ve made a note to look into this once I get back to the office on Wednesday.


No worries! I’m just still learning this. I wouldn’t be surprised if it’s something trivial I’m missing =D Thanks!

Dumb answer, but Adam has to keep a running average and squared running average of the gradients, so are you accounting for this as well?
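
Back-of-the-envelope, assuming everything is float32 and ignoring activations and the allocator’s cache (a rough sketch, not exact accounting):

def estimate_adam_training_memory(model, bytes_per_element=4):
  # float32 = 4 bytes per element
  n_params = sum(p.numel() for p in model.parameters())
  weights = n_params * bytes_per_element         # the parameters themselves
  grads = n_params * bytes_per_element           # one gradient per parameter
  adam_state = 2 * n_params * bytes_per_element  # exp_avg + exp_avg_sq
  return weights, grads, adam_state

# e.g. with the Discriminator above (~1.02e9 parameters):
# w, g, a = estimate_adam_training_memory(Discriminator())
# print(w / 1e9, g / 1e9, a / 1e9)  # roughly 4.1 GB, 4.1 GB, and 8.2 GB

That already puts parameters + gradients + Adam state at roughly 4x the size of the model, before counting activations or whatever the caching allocator keeps around.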


Ah! Interesting!

Yes, that indeed makes a difference: using RMSprop instead of Adam reduces the total allocated VRAM to 11.8 GB at loss.backward() and 15.6 GB at optimizer.step(). So the choice of optimizer matters.

Still not sure why optimizer.step() shows such a large increase, though :confused: Still, good to know! Any idea how to distribute such a large model across multiple GPUs easily? My initial thought was to manually place parts of the model on different devices and transfer the activations from device to device in forward(), but I am curious if there is a more elegant way.

Thanks!

Different optimisers require different amounts of memory. Plain SGD only involves the gradients, but if you want to store momentum for the weights, that doubles the memory requirement of the optimisation step. If you look at the Adam code, you’ll see it literally allocates the two buffers I mentioned at the first call of optim.step(), so different bits of memory can be allocated at different parts of the overall optimisation. I’ll have to leave the practice of model parallelism to someone with more than one GPU though~
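
If you want to see those buffers appear, here is a minimal sketch on a tiny toy model (recent PyTorch; the exact state keys are an implementation detail of torch.optim.Adam):

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 10)
opt = optim.Adam(model.parameters())

model(torch.randn(4, 10)).sum().backward()
opt.step()  # the running-average buffers are created here, on the first step

for p in model.parameters():
  print(list(opt.state[p].keys()))  # e.g. ['step', 'exp_avg', 'exp_avg_sq']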


I was about to ask the same kind of question when I saw this post.

I had a (relatively) small model built in Keras (Theano backend) that only took ~200-300 MB of GPU memory. However, the same model implemented in PyTorch takes 2200 MB of GPU memory. I’m using Adam, but then again I was using it in Keras as well…


miguelvr, in your case it’s probably different, and the memory measurement is probably also not accurate because of the caching allocator.
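
If it helps, newer PyTorch versions expose counters that separate what tensors actually occupy from what the caching allocator is holding on to (a small sketch, assuming torch.cuda.memory_allocated and torch.cuda.memory_reserved are available in your version):

import torch

x = torch.randn(1024, 1024, device='cuda:0')
print(torch.cuda.memory_allocated(0))  # bytes currently used by live tensors
print(torch.cuda.memory_reserved(0))   # bytes held by the caching allocator (closer to what nvidia-smi reports)

del x
print(torch.cuda.memory_allocated(0))  # drops once the tensor is freed
print(torch.cuda.memory_reserved(0))   # stays cached for reuse, so nvidia-smi barely moves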


Michael, you can do model parallelism simply by using the with torch.cuda.device(…) context manager, as described in the link.
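
For reference, the manual approach described earlier (explicitly placing submodules on different devices and moving the activations in forward()) looks roughly like this; the class name and layer sizes here are made up for illustration:

import torch
import torch.nn as nn

class SplitNet(nn.Module):
  def __init__(self):
    super(SplitNet, self).__init__()
    # convolutional front end lives on GPU 0
    self.features = nn.Sequential(nn.Conv2d(1, 64, 3), nn.ReLU()).cuda(0)
    # the huge fully connected layers live on GPU 1
    self.classifier = nn.Sequential(nn.Linear(64 * 254 * 254, 256), nn.ReLU(), nn.Linear(256, 2)).cuda(1)

  def forward(self, x):
    x = self.features(x.cuda(0))
    x = x.view(x.size(0), -1).cuda(1)  # hand the activations over to GPU 1
    return self.classifier(x)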
