CPU RAM explodes after 30 minibatch iterations

I have read over 10 reports of similar problems with exploding RAM on the CPU, and none of their suggested fixes have worked for me.

I can run my forward pass using just the encoder network without any problems, so I know my DataLoader and Encoder are fine. If I run the Decoder just for forward passes, without even storing or computing any losses, the RAM explodes.

So what is it about my decoder that causes this? I tried to follow the official PyTorch tutorial with the attention-based decoder (https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html) as closely as possible. The only difference in my code is that my Encoder processes the whole minibatch at once rather than a single example, so I have a double for loop (looping over the batch elements and then over each element's timesteps).

Here is my code to run the forward passes with a randomly created latent space (normally comes from the Encoder):

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

device = torch.device("cpu")
epochs = 5
learning_rate = 0.01

# initialize the NN
decoder_net = DecoderNet().to(device)

dataloader = DataLoader(df, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, collate_fn=pad_and_sort_batch)

# iterate through epochs
for e in range(1, epochs + 1):

    # counter for the number of minibatches passed
    mini_batch_iters = 0

    # randomly generated latent space: minibatch size of 32, latent vector of length 200
    latent = torch.randn((32, 200))

    # iterate through the data
    for i, batch in enumerate(dataloader):

        # find the longest sequence in the batch (batch[1] holds the sequence lengths)
        max_l = torch.max(batch[1])

        # holder for the outputs of the whole minibatch
        batch_outs = torch.zeros([BATCH_SIZE, max_l, VOCAB_SIZE], device=device, requires_grad=True)

        # iterate through the minibatch
        for length, b_ind in zip(batch[1], range(BATCH_SIZE)):
            # print('batch ind', b_ind)
            # print('length of example', length)

            # more holder variables, including the initial hidden states for the LSTMs
            seq_outs = torch.zeros([length, VOCAB_SIZE], device=device, requires_grad=True)
            prev_out = torch.zeros([VOCAB_SIZE], device=device, requires_grad=False)
            hidden = (torch.zeros([2, 1, int(CODE_LAYER_SIZE / 2)], device=device, requires_grad=False),
                      torch.zeros([2, 1, int(CODE_LAYER_SIZE / 2)], device=device, requires_grad=False))
            hidden_out = (torch.zeros([1, 1, VOCAB_SIZE], device=device, requires_grad=False),
                          torch.zeros([1, 1, VOCAB_SIZE], device=device, requires_grad=False))

            # for each element in the batch, run it through the 2 LSTMs at each timestep
            for t in range(length):
                prev_out, hidden, hidden_out = decoder_net(latent[b_ind, :], prev_out, hidden, hidden_out)
                seq_outs[t, :] = prev_out

            # add each sequence (padded up to max_l) to the batch output collector
            batch_outs[b_ind, :, :].add_(torch.cat((seq_outs, torch.zeros([(max_l - length), VOCAB_SIZE], device=device, requires_grad=False)), 0))

        mini_batch_iters += 1
        print('mini batches iters', mini_batch_iters)

After 30 minibatches my CPU RAM usage explodes to over 13 GB (starting at less than 2 GB), increasing roughly linearly.

Here is my simple DecoderNet:

BATCH_SIZE = 32
CODE_LAYER_SIZE = 200
VOCAB_SIZE = 21

class DecoderNet(nn.Module):

    def __init__(self):
        super(DecoderNet, self).__init__()
        # decode
        # self.dense1_predecode = nn.Linear(in_features=CODE_LAYER_SIZE, out_features=50)
        self.decoder = nn.LSTM(input_size=(CODE_LAYER_SIZE + VOCAB_SIZE), bidirectional=False,
                               hidden_size=int(CODE_LAYER_SIZE / 2), num_layers=2)
        self.decoder_out = nn.LSTM(input_size=int(CODE_LAYER_SIZE / 2), bidirectional=False,
                                   hidden_size=VOCAB_SIZE, num_layers=1)
        # may want to do bilinear times 2 as it is bidirectional
        # self.dense1_dec = nn.Linear(in_features=(CODE_LAYER_SIZE * 2), out_features=50)
        # self.dense2_dec = nn.Linear(in_features=50, out_features=20)

    def forward(self, latent_space, prev_out, hidden, hidden_out):
        # decoding: takes a single latent code for a single element of the batch,
        # where prev_out is the teacher-forcing token or the prediction from the previous step
        prev_out, hidden = self.decoder(torch.cat((latent_space, prev_out), 0).view(1, 1, -1), hidden)
        prev_out, hidden_out = self.decoder_out(prev_out, hidden_out)
        prev_out = F.log_softmax(prev_out, dim=2)

        return prev_out.squeeze_(), hidden, hidden_out

I don’t get what is wrong. Does it have to do with how I store the outputs from each timestep and minibatch entry? However, the official PyTorch tutorial (https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html) does something very similar.

UPDATE: While waiting for an answer, I ran the Decoder on the whole batch without the for loops (this isn’t what I ultimately want), and it worked without any memory issues. Clearly something is wrong with how I am using my for loops and storing intermediate computations.
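Roughly what I mean by the batched version (a simplified sketch, not my exact code: it drops the autoregressive feedback and the two-LSTM decoder, and just pushes the repeated latent code through a single LSTM over the whole padded batch in one call):

# simplified sketch of a fully batched forward pass (assumes no autoregressive
# feedback, so the whole padded batch can go through the LSTM in one call)
batched_decoder = nn.LSTM(input_size=CODE_LAYER_SIZE, hidden_size=VOCAB_SIZE, num_layers=1)

for i, batch in enumerate(dataloader):
    max_l = int(torch.max(batch[1]))

    # repeat the latent code at every timestep: shape (max_l, BATCH_SIZE, CODE_LAYER_SIZE)
    dec_in = latent.unsqueeze(0).repeat(max_l, 1, 1)

    out, _ = batched_decoder(dec_in)         # one LSTM call for the whole minibatch
    batch_outs = F.log_softmax(out, dim=2)   # (max_l, BATCH_SIZE, VOCAB_SIZE)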

Thank you in advance for any help!

In my understanding, from your code, you are not using any loss function to backpropagate; you are just running the forward calls for 30 minibatches. If that's the case, the computation graphs are never cleared and keep residing in your RAM. If you do not wish to backpropagate, run the model inside a with torch.no_grad() block, which avoids building the computation graphs.
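Roughly like this (a minimal sketch reusing the inner loop variables from your code):

# forward passes only: under no_grad() no autograd graph is recorded,
# so nothing accumulates across timesteps or minibatches
with torch.no_grad():
    for t in range(length):
        prev_out, hidden, hidden_out = decoder_net(latent[b_ind, :], prev_out, hidden, hidden_out)
        seq_outs[t, :] = prev_out
# (the holder tensors like seq_outs and batch_outs can then also be created
#  with requires_grad=False, since nothing needs gradients here)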

4 Likes

Thank you for your comment. However, when I did have a loss function and performed backprop, the RAM also exploded. I was only removing the loss to isolate the issue.

You know, I went to get my code with the loss still in it and ended up finding the real bug! I was only calling zero_grad() on my optimizer every epoch rather than every minibatch… Dumb mistake.

But it makes sense: keeping track of so many little gradients from this stepwise approach, minibatch after minibatch, made the RAM explode before the next epoch could reset them.

It’s actually fortunate this happened, as otherwise I wouldn’t have spotted the bug and would have kept doing backprop every minibatch on gradients accumulated from previous minibatches I had already backpropped on!
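For anyone who hits this later, the fix was just moving zero_grad() inside the minibatch loop. A minimal sketch with a toy stand-in model (not my actual decoder or data):

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)                   # toy stand-in for the decoder
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.NLLLoss()

for step in range(100):                    # stand-in for iterating the dataloader
    x = torch.randn(32, 10)
    target = torch.randint(0, 2, (32,))

    optimizer.zero_grad()                  # clear gradients every minibatch, not once per
                                           # epoch, otherwise they accumulate across minibatches
    out = torch.log_softmax(model(x), dim=1)
    loss = criterion(out, target)
    loss.backward()
    optimizer.step()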

Thank you for your help getting me to spot what was actually wrong!

1 Like

Another thing that can make your memory explode is tracking the accumulated loss over the training set without optimizing, when you do not call .data or use torch.no_grad(). For instance:

COST = Loss(y, net(x))
acc_cost += COST        # do not do this
acc_cost += COST.data   # do this

Calling .data is deprecated; however, for this little thing I still find it the easiest and simplest way.
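If you would rather avoid .data, the same trick works by detaching from the graph or by taking the plain Python number (a small sketch continuing the snippet above):

acc_cost += COST.detach()   # keeps a tensor but drops the attached graph
# or
acc_cost += COST.item()     # plain Python float, nothing from the graph is kept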

3 Likes