CrossEntropyLoss ValueError: Expected target size (2, 13), got torch.Size([2])

I am trying to train a Pytorch LSTM network, but I’m getting ValueError: Expected target size (2, 13), got torch.Size([2]) when I try to calculate CrossEntropyLoss. I think I need to change the shape somewhere, but I can’t figure out where. I’ve seen this error message in some posts but couldn’t find any that use cross entropy loss.

For this problem, I am trying to predict a word from the last three words. So, if “this is example text” appeared in the corpus, the corresponding feature would be [“this”, “is”, “example”] and the label would be [“text”] (with every word mapped to its id, of course).
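
The (feature, label) pairs are built with a sliding window over the id-mapped corpus, roughly like this (a simplified sketch with a toy corpus, not my exact preprocessing code):

import torch

# toy corpus and vocab just for illustration
corpus_tokens = "this is example text and this is more example text".split()
word_to_id = {w: i for i, w in enumerate(sorted(set(corpus_tokens)))}
corpus_ids = [word_to_id[w] for w in corpus_tokens]

features, labels = [], []
for i in range(len(corpus_ids) - 3):
    features.append(corpus_ids[i:i + 3])   # the three preceding words
    labels.append(corpus_ids[i + 3])       # the word to predict

features = torch.tensor(features)  # shape: [num_examples, 3]
labels = torch.tensor(labels)      # shape: [num_examples]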

Here is my network definition:

class LSTM(nn.Module):

    def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers, drop_prob=0.2):
        super(LSTM, self).__init__()

        # network size parameters
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim


        # the layers of the network
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim)
        self.lstm = nn.LSTM(self.embedding_dim, self.hidden_dim, self.n_layers, dropout=drop_prob, batch_first=True)
        self.dropout = nn.Dropout(drop_prob)
        self.fc = nn.Linear(self.hidden_dim, self.vocab_size)



    def forward(self, input, hidden):
        # Perform a forward pass of the model on some input and hidden state.
        batch_size = input.size(0)
        print(f'batch_size: {batch_size}')

        print(f'Input shape: {input.shape}')

        # pass through embeddings layer
        embeddings_out = self.embedding(input)
        print(f'Shape after Embedding: {embeddings_out.shape}')


        # pass through LSTM layers
        lstm_out, hidden = self.lstm(embeddings_out, hidden)
        print(f'Shape after LSTM: {lstm_out.shape}')


        # pass through dropout layer
        dropout_out = self.dropout(lstm_out)
        print(f'Shape after Dropout: {dropout_out.shape}')


        #pass through fully connected layer
        fc_out = self.fc(dropout_out)
        print(f'Shape after FC: {fc_out.shape}')

        # return output and hidden state
        return fc_out, hidden


    def init_hidden(self, batch_size):
        #Initializes hidden state
        # Create two new tensors with sizes n_layers x batch_size x hidden_dim,
        # initialized to zero, for hidden state and cell state of LSTM


        hidden = (torch.zeros(self.n_layers, batch_size, self.hidden_dim), torch.zeros(self.n_layers, batch_size, self.hidden_dim))
        return hidden

I added print statements showing the shape of the tensors at each step. My data is in a TensorDataset called training_dataset with two attributes, features and labels. features has shape torch.Size([97, 3]), and labels has shape torch.Size([97]).
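
The dataset is then just those two tensors wrapped together, along these lines (again a sketch, using the features and labels tensors from the snippet above):

from torch.utils.data import TensorDataset

# wraps the feature and label tensors; indexing the dataset
# returns one (features, labels) pair per example
training_dataset = TensorDataset(features, labels)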

This is the code for the network training:

# Size parameters
vocab_size = 13
embedding_dim = 256
hidden_dim = 256       
n_layers = 2     

# Training parameters
epochs = 3
learning_rate = 0.001
clip = 1
batch_size = 2


training_loader = DataLoader(training_dataset, batch_size=batch_size, drop_last=True, shuffle=True)

net = LSTM(vocab_size, embedding_dim, hidden_dim, n_layers)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
loss_func = torch.nn.CrossEntropyLoss()

net.train()
for e in range(epochs):
    print(f'Epoch {e}')
    print(batch_size)
    hidden = net.init_hidden(batch_size)

    # loops through each batch
    for features, labels in training_loader:

        # resets training history
        hidden = tuple([each.data for each in hidden])
        net.zero_grad()

        # computes gradient of loss from backprop
        output, hidden = net.forward(features, hidden)
        loss = loss_func(output, labels)
        loss.backward()

        # using clipping to avoid exploding gradient
        nn.utils.clip_grad_norm_(net.parameters(), clip)
        optimizer.step()

When I try to run the training I get the following error:

Traceback (most recent call last):
  File "train.py", line 75, in <module>
    loss = loss_func(output, labels)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 947, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/usr/local/lib/python3.8/site-packages/torch/nn/functional.py", line 2422, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/usr/local/lib/python3.8/site-packages/torch/nn/functional.py", line 2227, in nll_loss
    raise ValueError('Expected target size {}, got {}'.format(
ValueError: Expected target size (2, 13), got torch.Size([2])

Also here is the result of the print statements:

batch_size: 2
Input shape: torch.Size([2, 3])
Shape after Embedding: torch.Size([2, 3, 256])
Shape after LSTM: torch.Size([2, 3, 256])
Shape after Dropout: torch.Size([2, 3, 256])
Shape after FC: torch.Size([2, 3, 13])

There is some kind of shape error happening, but I can’t figure out where. Any help would be appreciated. If relevant I’m using Python 3.8.5 and Pytorch 1.6.0.

The output tensor of nn.LSTM with batch_first=True is returned in the shape [batch_size, seq_len, features]. Based on your description I guess you would like to use the activation of the last time step for the classification, so you might want to slice it via:

lstm_out, hidden = self.lstm(embeddings_out, hidden)
lstm_out = lstm_out[:, -1]  # keep only the last time step -> [batch_size, hidden_dim]

and process this tensor further.
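
With that slice the fully connected layer sees a [batch_size, hidden_dim] tensor and returns [batch_size, vocab_size] logits, which matches what nn.CrossEntropyLoss expects for a plain classification target (input [N, C], target [N]). A standalone shape check with your sizes (just a sketch, not your exact model):

import torch
import torch.nn as nn

batch_size, seq_len, hidden_dim, vocab_size = 2, 3, 256, 13

lstm_out = torch.randn(batch_size, seq_len, hidden_dim)   # stand-in for the LSTM output
fc = nn.Linear(hidden_dim, vocab_size)

last_step = lstm_out[:, -1]                           # [2, 256], last time step only
logits = fc(last_step)                                # [2, 13]
labels = torch.randint(0, vocab_size, (batch_size,))  # [2]

loss = nn.CrossEntropyLoss()(logits, labels)          # no shape error

(If you wanted a prediction at every time step instead, you would permute the logits to [batch_size, vocab_size, seq_len] and use targets of shape [batch_size, seq_len].)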

That’s exactly what I was looking for, thanks!