Expected input batch_size (1) to match target batch_size (10000)

I’m getting the error ValueError: Expected input batch_size (1) to match target batch_size (10000). for the following code:

```python
import nltk
import torch
import torch.nn as nn
import torch.optim as optim

class LogLinearLM(nn.Module):
    def __init__(self, vocab_size):
        super(LogLinearLM, self).__init__()
        # Input: two concatenated one-hot word vectors; output: scores over the vocabulary
        self.linear = nn.Linear(2*vocab_size, vocab_size)
        self.softmax = nn.LogSoftmax(dim=1)
    def forward(self, input):
        out = self.linear(input)
        return self.softmax(out)

losses = []
loss_fn = nn.NLLLoss()
model = LogLinearLM(10000)
optimizer = optim.SGD(model.parameters(), lr=0.001)
for _ in range(30):
    total_loss = 0
    for w1, w2, w3 in nltk.trigrams(train):
        X = torch.cat([one_hot(w1), one_hot(w2)])
        yhat = model(X.view(1, -1))
        y = torch.tensor(one_hot(w3), dtype=torch.long)
        loss = loss_fn(yhat, y)
        total_loss += loss
```

The function one_hot is:

```python
def one_hot(word):
    v = torch.zeros(10000)
    v[word2idx[word]] = 1
    return v
```
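For reference, the mismatch can be reproduced in isolation. This is a minimal sketch with random scores instead of the real model; the vocabulary size of 10000 matches the post, while the word index 5 is an arbitrary example. The point is the shapes: the prediction is a batch of 1, but the one-hot target has 10000 entries along dim 0, which `nn.NLLLoss` reads as 10000 separate targets.

```python
import torch
import torch.nn as nn

vocab_size = 10000
loss_fn = nn.NLLLoss()

# Model output for one sample: shape [1, vocab_size] (a batch of 1)
yhat = torch.log_softmax(torch.randn(1, vocab_size), dim=1)

# One-hot target as built in the post: shape [vocab_size]
target = torch.zeros(vocab_size, dtype=torch.long)
target[5] = 1  # arbitrary example word index

print(yhat.shape, target.shape)  # torch.Size([1, 10000]) torch.Size([10000])

# NLLLoss treats dim 0 of the target as the batch dimension, so it sees
# 10000 targets for a single prediction and raises a ValueError.
error_message = None
try:
    loss_fn(yhat, target)
except ValueError as err:
    error_message = str(err)
print(error_message)
```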

Can anyone help me, please? Thanks 🙂

Are you working on a multi-label classification task, i.e. can your target hold more than one valid class?
If so, I think you should try using nn.BCELoss (on sigmoid outputs instead of LogSoftmax, with a float target) and unsqueeze your target at dim0 using y = y.unsqueeze(0).
However, if that’s not the case and you are dealing with a multi-class classification, i.e. your target only stores one valid target class, you should use the class index instead of the one-hot encoded target.
You can get the class index using y = torch.argmax(y) and then add the batch dimension with y = y.unsqueeze(0), so that the target has shape [1].
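A minimal sketch of both options, assuming a batch of one sample, random scores, and an arbitrary class 5. For the multi-label case it uses nn.BCEWithLogitsLoss, which is swapped in here for nn.BCELoss: it applies the sigmoid internally, so it takes the raw linear outputs rather than LogSoftmax outputs.

```python
import torch
import torch.nn as nn

vocab_size = 10000
logits = torch.randn(1, vocab_size)  # raw scores for a batch of 1

one_hot_target = torch.zeros(vocab_size)
one_hot_target[5] = 1.0  # arbitrary example class

# Multi-class (one valid class per sample): use the class index.
# argmax returns a 0-dim tensor; unsqueeze(0) gives the shape [batch_size] = [1].
class_idx = torch.argmax(one_hot_target).unsqueeze(0)
nll = nn.NLLLoss()(torch.log_softmax(logits, dim=1), class_idx)

# Multi-label (several valid classes per sample): keep the multi-hot float
# target, add the batch dimension, and apply a binary loss to the raw logits.
bce = nn.BCEWithLogitsLoss()(logits, one_hot_target.unsqueeze(0))

print(class_idx)  # tensor([5])
print(nll.item(), bce.item())
```

For a trigram model only the first option applies, since each context predicts exactly one next word.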

The task is training a trigram model: given two one-hot vectors (representing two words), I concatenate them, and the expected label is a one-hot vector representing the target word. I’m only expecting one label; it’s just one-hot encoded. From what I know, I want to compare the softmax output to the target, which is what (I think) I’m doing…

OK, in that case the second approach would be valid.
If you have one valid class for each sample, your target should have the shape [batch_size] storing the class index. E.g. if the current word were class 5, you shouldn’t store it as [[0, 0, 0, 0, 0, 1, 0, ...]], but rather just use the class index torch.tensor([5]).
As described in the other post, you can achieve this using torch.argmax.
The docs have some more examples.
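Applied to the loop from the question, a minimal sketch could look like the following. Several pieces are assumptions for illustration only: the toy corpus and its word2idx stand in for the real `train` data; `zip(train, train[1:], train[2:])` replaces `nltk.trigrams(train)` (it yields the same consecutive triples) so the sketch runs without NLTK; the backward/step calls and the larger learning rate of 0.1 are added so that the toy model visibly trains. The key change is that the target is built directly as a class index of shape [1].

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical toy corpus standing in for `train` from the question.
train = "the cat sat on the mat the cat ate".split()
word2idx = {w: i for i, w in enumerate(sorted(set(train)))}
vocab_size = len(word2idx)

def one_hot(word):
    v = torch.zeros(vocab_size)
    v[word2idx[word]] = 1
    return v

class LogLinearLM(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.linear = nn.Linear(2 * vocab_size, vocab_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input):
        return self.softmax(self.linear(input))

torch.manual_seed(0)
model = LogLinearLM(vocab_size)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)  # assumed lr for the toy data

losses = []
for _ in range(30):
    total_loss = 0.0
    # Same consecutive (w1, w2, w3) triples as nltk.trigrams(train)
    for w1, w2, w3 in zip(train, train[1:], train[2:]):
        X = torch.cat([one_hot(w1), one_hot(w2)]).view(1, -1)  # [1, 2*vocab_size]
        y = torch.tensor([word2idx[w3]])  # class index, shape [1]
        optimizer.zero_grad()
        loss = loss_fn(model(X), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()  # .item() avoids keeping every graph alive
    losses.append(total_loss)

print(losses[0], losses[-1])
```

Accumulating `loss.item()` instead of `loss` is a design choice worth noting: summing the tensors, as in the original loop, keeps every computation graph of the epoch in memory.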

Let me know if that works.

I think it worked; the error is gone! Thank you so much! 🙂
