PyTorch LSTM: Target Dimension in Calculating Cross Entropy Loss

I’ve been trying to get an LSTM (an LSTM followed by a linear layer in a custom model) working in PyTorch, but was getting the following error when calculating the loss:

Assertion `cur_target >= 0 && cur_target < n_classes' failed.

I defined the loss function with:

criterion = nn.CrossEntropyLoss()

and then called it with:

loss += criterion(output, target)

I was passing the target with dimensions [sequence_length, number_of_classes], while the output had dimensions [sequence_length, 1, number_of_classes].

The examples I was following seemed to be doing the same thing, but the PyTorch docs on cross entropy loss describe the target differently.

The docs say the target should be of dimension (N), where each value satisfies 0 ≤ targets[i] ≤ C−1 and C is the number of classes. I changed the target to that form, but now I’m getting the following error (the sequence length is 75, and there are 55 classes):

Expected target size (75, 55), got torch.Size([75])

I’ve looked at solutions for both errors, but still can’t get this working properly. I’m confused about the proper dimensions of the target, as well as the actual meaning of the first error (different searches gave very different explanations of it, and none of the fixes worked).
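For reference, here is my understanding of the basic case described in the docs, as a minimal sketch with dummy shapes (no sequence dimension):

import torch
import torch.nn as nn

batch_size = 4    # N in the docs
num_classes = 55  # C in the docs

criterion = nn.CrossEntropyLoss()
output = torch.randn(batch_size, num_classes)          # raw scores: (N, C)
target = torch.randint(0, num_classes, (batch_size,))  # class indices: (N)
loss = criterion(output, target)  # works: every target value is in [0, C-1]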

Thanks


Try to permute your output and target so that the batch dimension is in dim0, i.e. your output should be [1, number_of_classes, seq_length], while your target should be [1, seq_length].
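For example, with the shapes from your post (sequence length 75, 55 classes), something like this should work:

import torch
import torch.nn as nn

seq_len, num_classes = 75, 55
criterion = nn.CrossEntropyLoss()

output = torch.randn(seq_len, 1, num_classes)       # [seq_length, 1, number_of_classes]
target = torch.randint(0, num_classes, (seq_len,))  # [seq_length]

# Move the batch dimension to dim0 and the class dimension to dim1
output = output.permute(1, 2, 0)  # -> [1, number_of_classes, seq_length]
target = target.unsqueeze(0)      # -> [1, seq_length]
loss = criterion(output, target)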


My LSTM didn’t learn when I passed the inputs to the criterion this way; when I stacked the batches, it worked:

# Flatten batch and time into one dimension: [seq * batch, num_classes]
outputs = outputs.view(outputs.shape[0] * outputs.shape[1], outputs.shape[2]).cpu()
# Matching 1D target: [seq * batch]
target = target.view(target.shape[0] * target.shape[1])
loss = criterion(outputs, target)


@ptrblck If I understand correctly, your suggested approach sums the losses over all timesteps of the sequence. I made an example with dummy tensor data to demonstrate this behavior:

import torch
import torch.nn as nn

batch_size = 5
seq_len = 9
num_classes = 36

scores = torch.rand((batch_size, num_classes, seq_len))         # model output: [N, C, seq]
targets = torch.randint(0, num_classes, (batch_size, seq_len))  # class indices: [N, seq]
criterion = nn.CrossEntropyLoss(reduction='sum')

# Loss computed over the full [N, C, seq] tensor in one call
auto_sum = criterion(scores, targets)

# Same loss accumulated one timestep at a time
manual_sum = 0
for timestep in range(scores.size(2)):
    score_t = scores[:, :, timestep]  # [N, C]
    target_t = targets[:, timestep]   # [N]
    manual_sum += criterion(score_t, target_t)
print(f'Manual {manual_sum:.6f} vs. auto {auto_sum:.6f}')

Would the auto_sum method be your preferred approach for loss calculation when training an encoder-decoder architecture for neural machine translation? Do you know of a preference among practitioners for summing across timesteps vs. averaging?

The same thing is happening with my model. The model is not learning when I have the output as [batch size, number_of_classes, seq_length] and the target as [batch size, seq_length].

It is important for me to make sure that I take the average over both timesteps and batch while calculating the cross-entropy loss. Is there any solution for this?

My model is not learning when I do what you suggested. On the other hand, it learns when I stack to get a 2D output and a 1D target tensor, but the latter doesn’t allow me to take the average over timesteps. Is there any solution?

Could you explain the use case that learns fine a bit more, i.e. how are you stacking the output, and what shapes do it and the target have?

When I have the output as [batch size, vocab size, seq length] and my target as [batch size, seq length], the model does not learn. Doing it this way is important for me, since the loss function then outputs [batch size, seq length], which allows me to take the average over both timesteps (i.e. seq length) and batch. Any idea why my model does not work in this case?
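Here is roughly what I am doing, as a minimal sketch with dummy shapes (I assume reduction='none', which is what keeps the loss in the [batch size, seq length] shape before averaging):

import torch
import torch.nn as nn

batch_size, vocab_size, seq_len = 4, 55, 75
criterion = nn.CrossEntropyLoss(reduction='none')

output = torch.randn(batch_size, vocab_size, seq_len)         # [batch size, vocab size, seq length]
target = torch.randint(0, vocab_size, (batch_size, seq_len))  # [batch size, seq length]

per_token_loss = criterion(output, target)  # [batch size, seq length]
loss = per_token_loss.mean()                # average over both batch and timesteps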

I tried squeezing, where my output is now [batch size * seq length, vocab size] and my target is [batch size * seq length]. In this approach my model learns, but it does not allow me to take the average over timesteps.

Note: vocab size for me is just the number of classes.

Are you using view() or reshape() to get the output into the required shape? If so, you might want to look at this post. In a nutshell, “carelessly” using view() or reshape() will mess up your output, which will very likely lead to your network not learning properly.
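A minimal sketch of the pitfall, with tiny dummy shapes so you can inspect the values:

import torch

seq_len, batch_size, num_classes = 3, 2, 4
x = torch.arange(seq_len * batch_size * num_classes).reshape(seq_len, batch_size, num_classes)

# "Careless" view(): it only reinterprets the underlying memory and does
# not swap dimensions, so samples and timesteps get mixed together.
wrong = x.view(batch_size, seq_len, num_classes)

# permute() actually reorders the dimensions; contiguous() is needed
# before any subsequent view()/flattening.
right = x.permute(1, 0, 2).contiguous()

print(torch.equal(wrong, right))  # False -> the class scores were shuffled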