NLLLoss expected target size for sequence tagger

I want to write a sequence tagger (NER, you could say) with a (Bi-)LSTM.
This is my code (I'm including all of it just FYI, but the most important parts are the LSTMTagger class at the top and the training loop at the bottom):

import sys, os
from torchnlp.word_to_vector import FastText
from pathlib import Path
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pad_sequence
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F
import re
from torchtext import data
from tqdm import tqdm, trange
import fasttext
import fasttext.util

EMBEDDING_DIM = 300
HIDDEN_DIM = 6
learning_rate = 0.01
epochs = 5
batch_size = 1

tag_to_ix = {"O": 0, "B-SKILL": 1, "I-SKILL": 2}

cuda_device = 0
device = torch.device("cuda:%d" % cuda_device if torch.cuda.is_available() else "cpu")

base_path = Path("..")
fasttext_model = base_path / "models" / "cc.de.300.bin"

class LSTMTagger(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)

        self.lstm_forward = nn.LSTM(embedding_dim, hidden_dim,
                                   batch_first=True,
                                   bidirectional=False)
        
        self.lstm_backward = nn.LSTM(embedding_dim, hidden_dim,
                                    batch_first=True,
                                    bidirectional=False)

        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def forward(self, embeds_forward, embeds_backward):

        lstm_out_forward, _ = self.lstm_forward(embeds_forward)
        # lstm_out_backward, _ = self.lstm_backward(embeds_backward)

        # lstm_concat = torch.cat((lstm_out_forward, lstm_out_backward), 2)
        
        tag_space = self.hidden2tag(lstm_out_forward)
        tag_scores = F.log_softmax(tag_space, dim=1)
        print(tag_scores.size()) # >> torch.Size([1, 2, 3])
        return tag_scores

class EmbeddingVectorizer:
    def __init__(self):

        self.embedding_model = fasttext.load_model(str(fasttext_model))

    def __call__(self, doc):
        """
        Convert address to embedding vectors
        :param address: The address to convert
        :return: The embeddings vectors
        """
        embeddings = []
        for word in doc:
            embeddings.append(self.embedding_model[word])
        return embeddings

embedding_model = EmbeddingVectorizer()

def pad_collate_fn(batch):
    sequences_vectors, sequences_labels, lengths = zip(*[
        (torch.FloatTensor(seq_vectors), torch.LongTensor(labels), len(seq_vectors))
        for (seq_vectors, labels) in sorted(batch, key=lambda x: len(x[0]), reverse=True)
    ])

    lengths = torch.LongTensor(lengths)

    padded_sequences_vectors = pad_sequence(sequences_vectors, batch_first=True, padding_value=0)

    # -100 is the default ignore_index of nn.NLLLoss, so padded label positions do not contribute to the loss
    padded_sequences_labels = pad_sequence(sequences_labels, batch_first=True, padding_value=-100)

    return (padded_sequences_vectors, lengths), padded_sequences_labels

def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)

def read_conll(file_path):
    file_path = Path(file_path)

    raw_text = file_path.read_text().strip()
    raw_docs = re.split(r'\n\t?\n', raw_text)
    token_docs = []
    tag_docs = []
    for doc in raw_docs:
        tokens = []
        tags = []
        for line in doc.split('\n'):
            token, tag = line.split('\t')
            tokens.append(token)
            tags.append(tag)
        token_docs.append(tokens)
        tag_docs.append(tags)

    return token_docs, tag_docs

start_path = Path("..")
data_path = start_path / "data" / "traindata_toy.conll"
texts, tags = read_conll(data_path)

training_data = []
for doc, doc_tags in zip(texts, tags):
    training_data.append((doc, doc_tags))  # Example data: [(["a", "b", "c"], ["O", "O", "B-SKILL"]), ([...], [...])]

word_to_ix = {}
for sent, tags in training_data:
    for word in sent:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)
    
embeddings_tags = []    

for doc, doc_tags in training_data:
    embs = embedding_model(doc)
    targets = prepare_sequence(doc_tags, tag_to_ix)
    embeddings_tags.append((embs, targets))

train_size = int(0.8 * len(embeddings_tags))
valid_size = int((len(embeddings_tags) - train_size) / 2)
test_size = valid_size
train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(embeddings_tags, [train_size, valid_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=pad_collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=pad_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=pad_collate_fn)


model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix))
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)


for epoch in trange(epochs, desc="Epoch"):
    for step, batch in enumerate(train_loader):

        docs = batch[0][0].to(device)
        tags = batch[1].to(device)
        
        model.zero_grad()
        
        doc_forward = docs
        doc_backward = torch.clone(doc_forward)
        doc_backward = torch.flip(doc_backward, [1])

        print(doc_forward.size())  # >> torch.Size([1, 2, 300])

        tag_scores = model(doc_forward, doc_backward)

        print(tags.size()) # >> torch.Size([1, 2])
        print(tag_scores.size()) # >> torch.Size([1, 2, 3])
        loss = loss_function(tag_scores, tags)
        loss.backward()
        optimizer.step()

I get the error
RuntimeError: Expected target size [1, 3], got [1, 2]
at
loss = loss_function(tag_scores, tags).
But shouldn't these sizes be right for NLLLoss? You can see the sizes I printed for a document with two tokens: the target tags have size [1, 2] (one batch, two tokens), and the predicted tag scores after log_softmax have size [1, 2, 3] (three possible tags), if I understand it correctly. But I still get this error. Do you have any idea? Thank you very much!

No, the sizes are wrong: nn.NLLLoss expects the model output to have the shape [batch_size, nb_classes, seq_len] and a target of shape [batch_size, seq_len] for a temporal multi-class classification.
Permute the model output and it should work.
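
A minimal sketch of that fix, assuming tag_scores comes out of the model as [batch_size, seq_len, nb_classes] (here [1, 2, 3]), as in your prints:

# tag_scores: [batch_size, seq_len, nb_classes] -> [batch_size, nb_classes, seq_len]
loss = loss_function(tag_scores.permute(0, 2, 1), tags)

Note that F.log_softmax should also be taken over the class dimension (dim=-1 on the [batch_size, seq_len, nb_classes] tensor), so that each token gets a proper log-probability distribution over the three tags.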


Thank you, it worked! Good to know that I should read the documentation more carefully… :sweat_smile: