Loss function and LSTM dimension issues

Hi all, I am writing a simple LSTM-based neural network to get some understanding of NER. I understand the overall idea but ran into some dimension issues. Here's the problem:

import torch
import torch.nn as nn
import torch.nn.functional as F

class NERModel(nn.Module):
    """
    Encoder for NER model.
    Args:
        - vocab_size: vocabulary size, integer.
        - embedding_size: embedding size, integer.
        - enc_units: hidden size of LSTM layer, integer.
        - ffc_units: hidden units of feedforward layer, integer.
        - num_labels: number of named entities. The value should be (actual_num_labels + 1),
            because zero paddings are added to the sequences.
    """
    
    def __init__(self, vocab_size, embedding_size, enc_units, ffc_units, num_labels):
        super(NERModel, self).__init__()
        # Word embedding layer.
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        # LSTM layer with units of enc_units
        self.LSTM = nn.LSTM(embedding_size, enc_units, batch_first=True)
        self.dense1 = nn.Linear(enc_units, ffc_units)
        self.dense2 = nn.Linear(ffc_units, num_labels)
        
    def forward(self, x):
        """
        Args:
            - x: Input tensor of shape (batch_size, sequence_length)
        Return:
            Tensor of shape (batch_size, sequence_length, num_labels)
        """
        x = self.embedding(x)
        # after embedding: torch.Size([64, 124, 256])
        x, _ = self.LSTM(x)
        # after lstm: torch.Size([64, 124, 256])
        x = self.dense1(x)
        # after linear 1: torch.Size([64, 124, 256])
        x = self.dense2(x)
        # after linear 2: torch.Size([64, 124, 6])
        output = F.log_softmax(x, dim=1)
        # after log_softmax: torch.Size([64, 124, 6])
        return output

# initialize model
model = NERModel(vocab_size, embedding_size, enc_units, ffc_units, num_labels)

optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
criterion = nn.CrossEntropyLoss()

for i, (value, label) in enumerate(train_loader):
    print(value.shape)
    optimizer.zero_grad()
    outputs = model(value)
    # outputs shape: torch.Size([64, 124, 6])
    # label shape: torch.Size([64, 124])
    loss = criterion(outputs, label)

Everything looked fine, but I got the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-9-92530e221aaf> in <module>()
---> 14         loss = criterion(outputs, label)

~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    940     def forward(self, input, target):
    941         return F.cross_entropy(input, target, weight=self.weight,
--> 942                                ignore_index=self.ignore_index, reduction=self.reduction)
    943 
    944 

~/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2054     if size_average is not None or reduce is not None:
   2055         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2056     return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
   2057 
   2058 

~/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   1879         if target.size()[1:] != input.size()[2:]:
   1880             raise ValueError('Expected target size {}, got {}'.format(
-> 1881                 out_size, target.size()))
   1882         input = input.contiguous().view(n, c, 1, -1)
   1883         target = target.contiguous().view(n, 1, -1)

ValueError: Expected target size (64, 6), got torch.Size([64, 124])

My outputs have shape torch.Size([64, 124, 6]) and my labels have shape torch.Size([64, 124]). It seems that the loss function wants the outputs to have shape torch.Size([64, 6, 124]). I don't understand why, and could someone tell me how to fix it?

nn.CrossEntropyLoss expects a model output of shape [batch_size, nb_classes, *additional_dimensions] and a target of shape [batch_size, *additional_dimensions] containing class indices in the range [0, nb_classes - 1], as explained in the docs.

Your output of shape [64, 124, 6] is therefore interpreted as [batch_size=64, nb_classes=124, additional=6], so the target would need shape [64, 6] with values in [0, 123].
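A minimal sketch of that shape convention, using random tensors with the shapes from this thread in place of real data. Depending on the PyTorch version, the mismatch surfaces as a ValueError (as in the traceback above) or a RuntimeError, so the sketch catches both.

```python
import torch
import torch.nn as nn

# Shapes from the thread: 64 sentences, 124 tokens each, 6 labels.
batch_size, seq_len, num_labels = 64, 124, 6
criterion = nn.CrossEntropyLoss()

# Target: one class index per token, shape [batch_size, seq_len].
label = torch.randint(0, num_labels, (batch_size, seq_len))

# Layout nn.CrossEntropyLoss expects: [batch_size, nb_classes, seq_len].
good_output = torch.randn(batch_size, num_labels, seq_len)
loss = criterion(good_output, label)  # works, returns a scalar

# Layout produced by the model in the question: [batch_size, seq_len, nb_classes].
# Dim 1 (size 124) is treated as the class dimension, so a [64, 6] target
# would be required and the [64, 124] target is rejected.
bad_output = torch.randn(batch_size, seq_len, num_labels)
try:
    criterion(bad_output, label)
    raised = False
except (ValueError, RuntimeError) as err:
    raised = True
    print(type(err).__name__, err)
```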

Could you explain your use case a bit and what the dimensions in your tensors mean?

Thanks for the reply.

All of the commented dimensions are the results of calling tensor.shape.

Here is what the specific numbers mean:

# print(model) result
NERModel(
  (embedding): Embedding(30290, 256)
  (LSTM): LSTM(256, 256, batch_first=True)
  (dense1): Linear(in_features=256, out_features=256, bias=True)
  (dense2): Linear(in_features=256, out_features=6, bias=True)
)

batch_size = 64
length_of_each_input = 124 words

When forward() is invoked, I don't modify any intermediate dimensions, and the output of my network is [64, 124, 6]. 64 is the batch size, and for each sentence of 124 words, 6 log_softmax values are computed per word.

It seems that I need to make it [64, 6, 124] so that I can call the loss function?

Yes, you should permute the output to be able to call nn.CrossEntropyLoss.
I'm not sure how your forward is implemented, but I would recommend checking the shapes of all intermediate tensors.
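A small sketch of the permute fix, with random tensors standing in for the model output and labels. It also shows flattening batch and time into one dimension, a common alternative (not discussed in the thread) that gives the same loss under the default reduction='mean'.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, num_labels = 64, 124, 6
criterion = nn.CrossEntropyLoss()

# Raw scores in the model's natural layout: [batch_size, seq_len, num_labels].
outputs = torch.randn(batch_size, seq_len, num_labels)
label = torch.randint(0, num_labels, (batch_size, seq_len))

# Option 1: move the class dimension to position 1 -> [64, 6, 124].
loss_permuted = criterion(outputs.permute(0, 2, 1), label)

# Option 2: flatten batch and time -> input [64*124, 6], target [64*124].
loss_flat = criterion(outputs.reshape(-1, num_labels), label.reshape(-1))

# Both average the per-token loss over the same 64*124 tokens.
print(loss_permuted.item(), loss_flat.item())
```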

I did print every intermediate tensor shape in forward(), and the shapes were printed during training. To make the code easier to read, I replaced those print statements with the actual results.


As you can see, I didn't modify the dimensions at all. Is it common to have to permute dimensions before calling a loss function?

The shapes look alright once you add the discussed permute.
However, F.log_softmax should be applied over the class dimension, which is dim=2 in your case (or permute the tensor before the operation and use dim=1).
Also, since you are using F.log_softmax, you should use nn.NLLLoss as the criterion.
nn.CrossEntropyLoss applies F.log_softmax and nn.NLLLoss internally.
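A quick check of the two valid pairings, on random logits already in the [batch, classes, seq_len] layout. As a side observation, log_softmax is idempotent (the log-probabilities already log-sum-exp to 0), so feeding log-probabilities into nn.CrossEntropyLoss gives the same value; the second log_softmax is redundant work rather than a numerical error.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(64, 6, 124)          # [batch, classes, seq_len], raw scores
label = torch.randint(0, 6, (64, 124))    # [batch, seq_len], class indices

# Pairing 1: raw logits into nn.CrossEntropyLoss (log_softmax happens inside).
ce = nn.CrossEntropyLoss()(logits, label)

# Pairing 2: explicit log_softmax over the class dimension, then nn.NLLLoss.
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), label)

# Mixing both applies log_softmax twice; redundant, but the value is unchanged.
double = nn.CrossEntropyLoss()(F.log_softmax(logits, dim=1), label)
print(ce.item(), nll.item(), double.item())
```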

Thanks! I just tried permute and the error is gone. But if I understand correctly, you mean that I need to do the permutation before log_softmax is called, right?

As for the loss function: if I'm using nn.CrossEntropyLoss, do you mean there's no need to apply F.log_softmax in forward()?

If you want to keep F.log_softmax(x, dim=1), then yes.
Otherwise, use dim=2 if you want to permute the tensor afterwards.
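To make the dim choice concrete, here is a small sketch on a random [64, 124, 6] tensor. With dim=1 the probabilities sum to 1 across the 124 positions (per label), not across the 6 labels (per token). The two correct orderings give identical results.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(64, 124, 6)   # [batch, seq_len, num_labels], as produced by dense2

# dim=1 normalizes across the sequence: each label column sums to 1 over 124 tokens.
wrong = F.log_softmax(x, dim=1)

# dim=2 normalizes across the labels for each token, then permute for the loss.
a = F.log_softmax(x, dim=2).permute(0, 2, 1)   # [64, 6, 124]

# Equivalent: permute first, then normalize over dim=1 (now the label dimension).
b = F.log_softmax(x.permute(0, 2, 1), dim=1)   # [64, 6, 124]
```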

Yes, since it will be applied twice at the moment (in your forward and inside nn.CrossEntropyLoss).
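Putting the answers together, a sketch of the model returning raw logits and a single training step that permutes before nn.CrossEntropyLoss. The layer sizes and the random batch are hypothetical small values chosen so the example runs quickly; they are not the sizes from the thread.

```python
import torch
import torch.nn as nn

class NERModel(nn.Module):
    """Same layers as in the question; forward returns raw scores (logits)."""

    def __init__(self, vocab_size, embedding_size, enc_units, ffc_units, num_labels):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.LSTM = nn.LSTM(embedding_size, enc_units, batch_first=True)
        self.dense1 = nn.Linear(enc_units, ffc_units)
        self.dense2 = nn.Linear(ffc_units, num_labels)

    def forward(self, x):
        # x: [batch, seq_len] token ids
        x = self.embedding(x)      # [batch, seq_len, embedding_size]
        x, _ = self.LSTM(x)        # [batch, seq_len, enc_units]
        x = self.dense1(x)         # [batch, seq_len, ffc_units]
        x = self.dense2(x)         # [batch, seq_len, num_labels]
        return x                   # no log_softmax: nn.CrossEntropyLoss applies it

# Hypothetical small sizes so the sketch runs fast.
model = NERModel(vocab_size=100, embedding_size=16, enc_units=16, ffc_units=16, num_labels=6)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

value = torch.randint(0, 100, (4, 10))   # fake batch: 4 sentences, 10 tokens
label = torch.randint(0, 6, (4, 10))     # fake per-token labels

optimizer.zero_grad()
outputs = model(value)                               # [4, 10, 6]
loss = criterion(outputs.permute(0, 2, 1), label)    # input [4, 6, 10], target [4, 10]
loss.backward()
optimizer.step()
print(loss.item())
```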


On a side note: you might want to add a non-linearity such as ReLU between your two linear layers. Without one, two stacked linear layers are equivalent to a single linear layer, so you don't gain much from having both.
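A small numerical check of that claim, using the layer sizes printed by print(model) above. Without an activation, dense2(dense1(x)) = W2(W1 x + b1) + b2 = (W2 W1) x + (W2 b1 + b2), which is one Linear(256, 6) with merged parameters; inserting torch.relu breaks the equivalence.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dense1 = nn.Linear(256, 256)
dense2 = nn.Linear(256, 6)

# Build the single equivalent layer: weight = W2 @ W1, bias = W2 @ b1 + b2.
merged = nn.Linear(256, 6)
with torch.no_grad():
    merged.weight.copy_(dense2.weight @ dense1.weight)
    merged.bias.copy_(dense2.weight @ dense1.bias + dense2.bias)

x = torch.randn(8, 124, 256)          # [batch, seq_len, enc_units]
stacked = dense2(dense1(x))           # two linear layers, no activation
collapsed = merged(x)                 # one linear layer

# With ReLU in between, the composition is no longer a single linear map.
with_relu = dense2(torch.relu(dense1(x)))
```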
