Calculating loss for entire batch using NLLLoss in 0.4.0

Hello! I’m trying to move to 0.4.0 and improve the performance of my sequence-to-sequence model, and I’m stuck at calculating the loss. Sorry for my poor English… I’ll try to explain my problem.

Earlier, on version 0.3, I was running a single “dataset-unit” through the model and then calculating the loss. A dataset-unit is a pair of two tensors, the input sentence and the target sentence, plus the target indexes of the words from the vocabulary.

Input-sentence: torch.Size([13, 1, 100])
Target-sentence: torch.Size([3, 1, 100])
Target-indexes: torch.Size([3])

Where 100 is the embedding size, and 13 and 3 are the numbers of embedded words.

Model:

class SeqModel(nn.Module):
  def __init__(self, input_size, hidden_size, output_size, layers):
    super(SeqModel, self).__init__()

    self.encoder = nn.GRU(input_size, hidden_size, layers)
    self.decoder = nn.GRU(input_size, hidden_size, layers)
    self.out = nn.Linear(hidden_size, output_size)

    self.hidden_state = None
    self.start_hidden = Variable(torch.zeros(layers, hidden_size))
    if USE_CUDA:
      self.start_hidden = self.start_hidden.cuda()

  def encode(self, embedded_sentence):
    outputs, self.hidden_state = self.encoder(embedded_sentence,  self.start_hidden)
    return outputs

  def decode(self, inputs):
    outputs_h, self.hidden_state = self.decoder(inputs, self.hidden_state)
    outputs = F.log_softmax(self.out(outputs_h.squeeze(0)), dim=-1)
    return outputs
    
  def forward(self, embedded_sentence, reply_vectors):
    enc_outs = self.encode(embedded_sentence)
    dec_outs = self.decode(reply_vectors)
    return dec_outs

And the loss was calculated like this:

def train_iter(model, criterion, data):
  loss = 0
  loss_val = 0
  for phrase, reply in data:
    data_enc_in = phrase['vectors']   # =  input-sentence, e.g. torch.Size([13, 1, 100]) 
    data_dec_in = reply['vectors']   # =  target-sentence, e.g. torch.Size([3, 1, 100])
    data_target = reply['indexes']   # =  target-indexes, e.g. torch.Size([3])
    
    outputs = model(data_enc_in, data_dec_in)
    loss = criterion(outputs.squeeze(1), data_target)
    loss.backward()
    loss_val += loss.data[0]
  return loss_val / len(data)

After several such iterations I would call opt.step().
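Roughly, the outer loop looked like this (just a sketch; data_chunks and ACCUMULATE_EVERY are illustrative names, not from the code above):

opt.zero_grad()
for step, chunk in enumerate(data_chunks, start=1):
  avg_loss = train_iter(model, criterion, chunk)   # backward() is called inside train_iter()
  if step % ACCUMULATE_EVERY == 0:
    opt.step()       # apply the gradients accumulated over the last few chunks
    opt.zero_grad()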

But now I want to pass the whole batch through the model to improve performance. The model is slightly different:

class SeqModel(nn.Module):
  def __init__(self, input_size, hidden_size, output_size, layers):
    super(SeqModel, self).__init__()

    self.encoder = nn.GRU(input_size, hidden_size, layers)
    self.decoder = nn.GRU(input_size, hidden_size, layers)
    self.out = nn.Linear(hidden_size, output_size)

    self.hidden_state = None
    self.start_hidden = torch.zeros(layers, hidden_size).to(device)

  def encode(self, embedded_sentence):
    h_stack = [self.start_hidden for h in range(embedded_sentence.shape[1])]
    h0 = torch.stack(h_stack, dim=1)
    
    outputs, self.hidden_state = self.encoder(embedded_sentence, h0)
    return outputs

  def decode(self, inputs):
    outputs_h, self.hidden_state = self.decoder(inputs, self.hidden_state)
    outputs = F.log_softmax(self.out(outputs_h.squeeze(0)), dim=-1)
    return outputs
    
  def forward(self, embedded_sentence, reply_vectors):
    enc_outs = self.encode(embedded_sentence)
    dec_outs = self.decode(reply_vectors)
    return dec_outs

And now each dataset-unit contains:
Input-sentence: torch.Size([13, 100])
Target-sentence: torch.Size([3, 100])
Target-indexes: torch.Size([3])

I removed the second dimension to allow stacking the tensors into a batch.

Now I’m trying to do the same demo training iteration, but for the whole batch at once:

seq_simple_test = SeqModel(EMBEDDING_SIZE, 100, VOCAB_SIZE, 1).to(device)

# Batch of input sentences, each of torch.Size([13, 100])
inp = torch.stack([test_unit[0]['vectors'] for test_unit in test_batch], dim=1)
# Batch of target sentences, each of torch.Size([3, 100])
dinp = torch.stack([test_unit[1]['vectors'] for test_unit in test_batch], dim=1) 

print(inp.size(), dinp.size())
# ==> torch.Size([13, 2, 100]) torch.Size([3, 2, 100])
# where 2 is the batch size

dec_outs = seq_simple_test(inp, dinp)

print(dec_outs.size())
# ==> torch.Size([3, 2, 35620])
# where 2 is the batch size and 35620 is the vocabulary (and softmax) size

opt_test = optim.Adam(seq_simple_test.parameters(), lr=0.001)
criterion_test = nn.NLLLoss()

for i in range(len(test_batch)):
  replies = torch.index_select(dec_outs, 1, torch.LongTensor([i]).to(device)).squeeze(1)
  print(replies.size())
  # ==> torch.Size([3, 35620])
  target_reply = test_batch[i][1]['indexes']
  print(target_reply.size())
  # ==> torch.Size([3])
  loss = criterion_test(replies, target_reply)
  loss.backward()

opt_test.step()

And I’m getting:

Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

I tried solutions from other threads. For example, I can call loss.backward(retain_graph=True), but backpropagating through the whole graph for every sample decreases my model’s performance (I tried it: it takes 7 minutes on a 2000-sample batch).
I’ve also tried:

loss = 0
for i in range(len(test_batch)):
  replies = torch.index_select(dec_outs, 1, torch.LongTensor([i]).to(device)).squeeze(1)
  target_reply = test_batch[i][1]['indexes']
  loss += criterion_test(replies, target_reply)
loss.backward()

This works, but I’ve read that loss += criterion(…) is not recommended because it increases memory usage. I tried this method on a batch of 2000 samples:

First time:
Ran the batch through the model: 37 milliseconds
Calculated the loss: 1 minute, 24 seconds
(1 minute, 24 seconds means ~24 samples per second; this is 4 times slower than what I had earlier.)

The second time it fails with:

cuda runtime error (2) : out of memory at /pytorch/aten/src/THC/generic/THCStorage.cu:58
(I’m using Google Colab)

I’ve also tried passing the whole batch of results to the criterion at once, but it seems that NLLLoss doesn’t accept it.

So, how should I calculate the loss to keep performance high? Thank you in advance!


You have to reset the gradients with optimizer.zero_grad() before calling loss.backward() again.

That doesn’t make sense for my case; I need to accumulate the gradients.
For example:

I’m sorry, I said something stupid. Are you still reinitializing the hidden states between samples? If you don’t want to reinitialize them, have you tried detaching the hidden states between batches? Usually you don’t need the entire history of the hidden states from one example to the next.
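Something like this, just a rough sketch of the idea:

# Between batches: keep the values of the hidden state, but cut its autograd
# history so backward() never tries to reach into the previous batch's graph.
self.hidden_state = self.hidden_state.detach()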

When a new tensor is passed to the model.encode() function (or to model.forward()), self.hidden_state gets reinitialized.

reinitializing the hidden states between samples

Earlier, when I passed one training sample to the model at a time, the hidden_state was reinitialized between samples each time.
But now I pass the whole batch to the model, and there is no need to reinitialize hidden_state between samples.
The problem is that I need to calculate the loss for the whole batch. When I try to pass my decoder outputs and my target indexes (torch.Size([3, 2, 35620]) and torch.Size([3, 2])) to NLLLoss, I get an error about inconsistent tensor sizes. It seems NLLLoss wants me to use one-hot encoding to compute the loss, but one-hot encoding would greatly increase memory usage, which makes it impossible to use big batches in Google Colab.

NLLLoss doesn’t ask for one-hot encoding. The input to NLLLoss is (N, C), where C is the number of classes, and the target is (N). Are you inputting elements one by one and accumulating the loss for each element? Can you recode it so that you input larger batches instead of a batch size of 1? Accumulating the loss for each element will surely result in a huge memory requirement.
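For example (just an illustration; the 6 here stands for N, e.g. seq_length * batch_size, and 35620 for your number of classes C):

import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.NLLLoss()
log_probs = F.log_softmax(torch.randn(6, 35620), dim=-1)   # input:  (N, C) log-probabilities
targets = torch.randint(0, 35620, (6,), dtype=torch.long)  # target: (N) class indices, no one-hot encoding
loss = criterion(log_probs, targets)                        # one scalar for all N predictions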

Can you recode so that you can input larger batches instead of batch size of 1?

That is exactly my question: how do I recode it to input the entire batch instead of calculating the loss one by one?
If I try to input the batch to the criterion, I get:

ValueError: Expected target size (7, 35620), got torch.Size([7, 50])

I’m inputting torch.Size([7, 50, 35620]) as decoder outputs and torch.Size([7, 50]) as target indexes,
where 50 is the batch size, 35620 is the vocabulary size (number of classes), and 7 is the number of predicted words.

It works when I pass them one by one (torch.Size([7, 35620]) with torch.Size([7])), and I don’t know how to adapt it to batch input.

EDIT:
OMG, I tried the “batch_first” approach earlier and it didn’t work. But I hadn’t tried putting batch_size in the last position, like torch.Size([7, 35620, 50]) and torch.Size([7, 50]), and that works now. Is this the correct way to compute the loss?

opt.zero_grad()
for i in range(iterations):
  ...
  loss = criterion_test(dec_outs.view(-1, vocab_size, batch_size), targets.view(-1, batch_size))
  loss.backward()
# after a few iterations
opt.step()

I think you need to do criterion_test(dec_outs.view(-1, vocab_size), targets.view(-1)).
In your case, (C) -> vocab_size and (N) -> batch_size * seq_length. I am assuming all the batches have the same sequence length; if not, you’ll have to use pack_padded_sequence and also mask the loss for the pad token.
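For example, with the shapes from your post (a sketch that reuses the names from your snippet):

# dec_outs: (seq_len, batch_size, vocab_size) = (7, 50, 35620)
# targets:  (seq_len, batch_size)             = (7, 50)
seq_len, batch_size, vocab_size = dec_outs.size()
loss = criterion_test(dec_outs.view(seq_len * batch_size, vocab_size),
                      targets.view(seq_len * batch_size))
loss.backward()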


Yes, you are right! The training speed of your way of calculating the loss is comparable to what I tried, but the loss decreases much faster.
The speed is now just mind-blowing: 30 times faster than what I had earlier. Thank you!!