[resolved] GPU out of memory error with batch size = 1

Hello,

I am taking my first steps in PyTorch, so I apologize in advance in case my issue is caused by some very silly mistake of my own.

I was following the PyTorch tutorial on LSTMs (http://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html) and implemented the exercise suggested at the end (augmenting the LSTM part-of-speech tagger with character-level features). I used the part-of-speech dataset available at http://www.cnts.ua.ac.be/conll2000/chunking/ (discarding the third column, i.e. the chunking tags). This dataset has over 7000 sentences.

Just like in the original code provided in the tutorial, training is done using a single sentence per iteration (i.e. the batch size is 1). Strangely, around iteration 500, the following error appears:

THCudaCheck FAIL file=/b/wheel/pytorch-src/torch/lib/THC/generic/THCStorage.cu line=66 error=2 : out of memory
Traceback (most recent call last):
  File "sequence_models_tutorial_complete.py", line 268, in <module>
    tag_scores = model(sentence_in, words_in)
  File "/home/dpernes/.anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)

Using nvidia-smi, I can confirm that the occupied memory increases during training, until it reaches the 4 GB available on my GTX 970. I suspect that, for some reason, PyTorch is not freeing memory from one iteration to the next, and so it ends up consuming all the available GPU memory.
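
For reference, one can also log the used memory from inside the training loop with a small helper around nvidia-smi (this helper is not part of my script or of the tutorial; it is just an illustration of how I monitor the growth):

import subprocess

def log_gpu_memory(tag=""):
    # query the used GPU memory via nvidia-smi, e.g. once every PRINT_EVERY iterations
    used = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader"]
    ).decode().strip()
    print("GPU memory used ({}): {}".format(tag, used))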

Here is the definition of my model:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd


class CharLvlRep(nn.Module):
    # character-level LSTM that produces a fixed-size representation of a word

    def __init__(self, embedding_dim, rep_dim, char_size):
        super(CharLvlRep, self).__init__()
        self.char_embeddings = nn.Embedding(char_size, embedding_dim).cuda()
        self.lstm = nn.LSTM(embedding_dim, rep_dim).cuda()

    def forward(self, word, lstm_istate, is_train=False):
        # word: LongTensor with the character indices of a single word
        embeds = self.char_embeddings(word)
        lstm_istate_var = (autograd.Variable(lstm_istate[0], requires_grad=False, volatile=not is_train),
                           autograd.Variable(lstm_istate[1], requires_grad=False, volatile=not is_train))
        char_reps, _ = self.lstm(embeds.view(len(word), 1, -1), lstm_istate_var)
        # the LSTM output at the last character is the word's character-level representation
        final_char_rep = char_reps[char_reps.size()[0] - 1, :, :]
        return final_char_rep

class LSTMTagger(nn.Module):

    def __init__(self, char_embedding_dim, char_rep_dim, char_size, word_embedding_dim, vocab_size, hidden_dim, tagset_size):
        super(LSTMTagger, self).__init__()
        self.char_rep_dim = char_rep_dim

        self.model_char = CharLvlRep(char_embedding_dim, char_rep_dim, char_size)
        self.word_embeddings = nn.Embedding(vocab_size, word_embedding_dim).cuda()

        # The LSTM takes the concatenation of word embeddings and character-level
        # representations as input, and outputs hidden states with dimensionality hidden_dim.
        self.lstm = nn.LSTM(word_embedding_dim + char_rep_dim, hidden_dim).cuda()

        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size).cuda()

    def forward(self, sentence, words, word_lstm_istate, char_lstm_istate, is_train=False):
        word_embeds = self.word_embeddings(sentence)
        word_embeds = word_embeds.view(len(sentence), 1, -1)

        # one character-level representation per word in the sentence
        char_reps = autograd.Variable(torch.zeros(len(words), 1, self.char_rep_dim).cuda(), volatile=not is_train)
        for idx, word in enumerate(words):
            char_reps[idx, :, :] = self.model_char(word, char_lstm_istate, is_train)

        # concatenate word embeddings with the character-level representations
        embeds_cat = torch.cat((word_embeds, char_reps), dim=2)

        word_lstm_istate_var = (autograd.Variable(word_lstm_istate[0], requires_grad=False, volatile=not is_train),
                                autograd.Variable(word_lstm_istate[1], requires_grad=False, volatile=not is_train))

        lstm_out, _ = self.lstm(embeds_cat, word_lstm_istate_var)
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        tag_scores = F.log_softmax(tag_space)
        return tag_scores

And here is my training loop:

for epoch in range(N_EPOCH):
  is_train = True
  train_accuracy = 0.0
  for idx, (sentence, tags) in enumerate(train_data):
    # new training sequence -> zero the gradients of all models
    model.zero_grad()
    
    # encode the chars of each word as character indices
    words_in = []
    for word in sentence:
      words_in.append(prepare_sequence(word, char_to_ix, is_train))
    # encode the words of the sentence as word indices
    sentence_in = prepare_sequence(sentence, word_to_ix, is_train)
    
    # compute the scores (forward pass)
    tag_scores = model(sentence_in, words_in, WORD_ISTATE, CHAR_ISTATE, is_train)
    
    # encode the labels for each word as tag indices
    targets = prepare_sequence(tags, tag_to_ix, is_train)
    
    # compute the accuracy
    train_accuracy += get_accuracy(tag_scores, targets)
    
    # compute the loss, gradients, and update the parameters by
    # calling optimizer.step()
    loss = loss_function(tag_scores, targets)
    loss.backward()
    optimizer.step()
    
    loss_hyst.append(loss.data[0])
    
    if idx % PRINT_EVERY == 0:
      print("It {}: loss = {}".format(idx,loss.data[0]))
  
  train_accuracy /= len(train_data)
  train_accuracy_hyst.append(train_accuracy.data[0])
  print("Epoch {}: train_accuracy = {}".format(epoch, train_accuracy.data[0]))
  
  # evaluate the validation accuracy after each epoch
  is_train = False
  valid_accuracy = 0.0
  for sentence, tags in valid_data:
    
    # encode the chars of each word as character indices
    words_in = []
    for word in sentence:
      words_in.append(prepare_sequence(word, char_to_ix, is_train))
    # encode the words of the sentence as word indices
    sentence_in = prepare_sequence(sentence, word_to_ix, is_train)
    
    # compute the scores (forward pass)
    tag_scores = model(sentence_in, words_in, WORD_ISTATE, CHAR_ISTATE, is_train)
    
    # encode the labels for each word as tag indices
    targets = prepare_sequence(tags, tag_to_ix, is_train)
    
    # compute the accuracy
    valid_accuracy += get_accuracy(tag_scores, targets)
    
  valid_accuracy /= len(valid_data)
  valid_accuracy_hyst.append(valid_accuracy.data[0])
  print("Epoch {}: valid_accuracy = {}".format(epoch, valid_accuracy.data[0]))
  
  # save the best model so far
  if valid_accuracy.data[0] > best_valid_accuracy:
    torch.save(model.state_dict(), SAVE_PATH)
    best_valid_accuracy = valid_accuracy.data[0]

where the function prepare_sequence() is defined as:

def prepare_sequence(seq, to_ix, is_train=False):
  idxs = [to_ix[w] for w in seq]
  tensor = torch.LongTensor(idxs).cuda()
  # if not in training mode, return a volatile variable (no backward pass)
  return autograd.Variable(tensor, requires_grad=False, volatile=not is_train)

Just found the issue! My function get_accuracy() was returning the Variable accuracy instead of the tensor accuracy.data. Since the return value of this function is accumulated at every training iteration (train_accuracy += get_accuracy(tag_scores, targets)), each accumulated Variable kept its whole computation graph alive, so the memory usage was increasing immensely.

I replaced return accuracy with return accuracy.data[0] in get_accuracy() and the problem is solved!
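
For anyone hitting the same thing, here is a minimal sketch of the fix (my actual get_accuracy() is not shown above, so treat the body below as an illustration of the idea rather than the exact code):

def get_accuracy(tag_scores, targets):
    # most likely tag for each word
    _, predicted = torch.max(tag_scores, 1)
    correct = (predicted.view(-1) == targets).float().mean()
    # returning correct (a Variable) keeps this iteration's computation graph
    # alive once it is accumulated across iterations; returning the plain
    # number breaks that reference and lets the memory be freed
    return correct.data[0]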


@dpernes I am facing a similar issue, but I am not quite able to understand why it is failing. I have implemented a custom regularizer which, for each weight matrix W, penalizes the squared largest absolute entry of W^T W - I and adds the sum of these penalties to the existing loss.

def l2_rego(mdl):
    l2_reg = None
    for W in mdl.parameters():
        if W.ndimension() < 2:
            continue
        else:
            wt = torch.transpose(W, 0, 1)
            m = torch.matmul(wt, W)
            cols = W.size(1)  # W^T W is (cols x cols)
            ident = Variable(torch.eye(cols, cols), requires_grad=True)
            ident = ident.cuda()
            w_tmp = (m - ident)
            if l2_reg is None:
                l2_reg = ((torch.max(torch.abs(w_tmp))) ** 2)
            else:
                l2_reg = l2_reg + ((torch.max(torch.abs(w_tmp))) ** 2)
            # call("nvidia-smi")
    return l2_reg
def train(train_loader, model, criterion, optimizer, epoch,odecay):
    """Train for one epoch on the training set"""
    for i, (input, target) in enumerate(train_loader):
        target = target.cuda(async=True)
        input = input.cuda()
        input_var = torch.autograd.Variable(input)
        target_var = torch.autograd.Variable(target)

        # compute output
        output = model(input_var)

        #Compute Loss
        oloss =  l2_rego(model)
        loss = criterion(output, target_var)
        loss = loss + oloss

loss and oloss here are both of type Variable. Within the first epoch I am getting a memory error:

RuntimeError: cuda runtime error (2) : out of memory at /pytorch/torch/lib/THC/generic/THCStorage.cu:58

I fear I am hitting the same issue, but I am not quite able to pin it down.

Hi @nbansal90,

I couldn’t find the error by simply looking at your code… Have you already tried removing the call to the regularization function from your training loop? If the issue disappears, then there’s something in that function. If it persists, the problem must be somewhere else…
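
Concretely, the ablation I have in mind is just something like this inside train() (a sketch, reusing your variable names):

    # compute output
    output = model(input_var)

    # ablation: skip the extra term and check whether the memory still grows
    # oloss = l2_rego(model)
    loss = criterion(output, target_var)
    # loss = loss + oloss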