I’m getting the error

```
ValueError: Expected input batch_size (1) to match target batch_size (10000).
```

for the following code:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import nltk

def one_hot(word):
    v = torch.zeros(10000)
    v[word2idx[word]] = 1
    return v

class LogLinearLM(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.linear = nn.Linear(2 * vocab_size, vocab_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input):
        out = self.linear(input)
        return self.softmax(out)

loss_fn = nn.NLLLoss()
model = LogLinearLM(10000)
optimizer = optim.SGD(model.parameters(), lr=0.001)

for _ in range(30):
    total_loss = 0
    for w1, w2, w3 in nltk.trigrams(train):
        X = torch.cat([one_hot(w1), one_hot(w2)])
        yhat = model(X.view(1, -1))
        y = torch.tensor(one_hot(w3), dtype=torch.long)
        loss = loss_fn(yhat, y)
        total_loss += loss
```
Can anyone help me, please? Thanks
Are you working on a multi-label classification task, i.e. does your target hold more than one valid class?
If so, try `nn.BCELoss` and unsqueeze your target at dim0 via `y = y.unsqueeze(0)`.
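For the multi-label case, a minimal sketch of what that would look like (the sizes and target values here are assumptions for illustration, not taken from the original post):

```python
import torch
import torch.nn as nn

# Toy multi-label setup: one sample, four classes, two of them active.
probs = torch.sigmoid(torch.randn(1, 4))  # model output in [0, 1], shape [1, 4]
y = torch.tensor([0., 1., 0., 1.])        # multi-hot target, shape [4]
y = y.unsqueeze(0)                        # -> shape [1, 4], matching the output
loss = nn.BCELoss()(probs, y)             # scalar loss
```

Note that `nn.BCELoss` expects probabilities in `[0, 1]`, so the model output would need a sigmoid (or use `nn.BCEWithLogitsLoss` on raw logits).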
However, if that’s not the case and you are dealing with multi-class classification, i.e. your target stores only one valid class, you should pass the class index instead of the one-hot encoded target.
You can get the class index via `y = torch.argmax(y)`.
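A small sketch of that conversion (toy vocabulary size, assumed for illustration; depending on the PyTorch version the target may also need to be a 1-D tensor of shape `[1]`, hence the extra `unsqueeze` here):

```python
import torch
import torch.nn as nn

vocab_size = 10  # toy size; the original model uses 10000
log_probs = nn.LogSoftmax(dim=1)(torch.randn(1, vocab_size))  # shape [1, 10]

one_hot_y = torch.zeros(vocab_size, dtype=torch.long)
one_hot_y[3] = 1                     # one-hot target: length 10, read as batch 10
y = torch.argmax(one_hot_y)          # class index: tensor(3)
loss = nn.NLLLoss()(log_probs, y.unsqueeze(0))  # target shape [1] matches batch 1
```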
The task is training a trigram model, so given two one-hot vectors (representing two words), I concatenate them, and the expected label is a one-hot vector representing the target word. I’m only expecting one label, it’s just one-hot encoded, and from what I know, I want to compare the softmax output to the target, which is what (I think) I’m doing…
OK, in that case the second approach would be valid.
If you have one valid class for each sample, your target should have the shape
`[batch_size]` and store the class index. E.g. if the current word were class 5, you shouldn’t store it as
`[[0, 0, 0, 0, 0, 1, 0, ...]]`, but rather just use the class index `5`.
As described in the other post, you can achieve this via `y = torch.argmax(y)`.
The docs have some more examples.
Let me know if that works.
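Applied to the trigram loop from the question, the target line would change roughly as follows (a sketch: `word2idx` is the mapping from the original `one_hot` helper, and the tiny vocabulary and stand-in model output here are assumptions for illustration):

```python
import torch
import torch.nn as nn

word2idx = {"the": 0, "cat": 1, "sat": 2}  # toy vocab for illustration
vocab_size = len(word2idx)
loss_fn = nn.NLLLoss()
yhat = nn.LogSoftmax(dim=1)(torch.randn(1, vocab_size))  # stand-in model output

w3 = "sat"
# old: y = torch.tensor(one_hot(w3), dtype=torch.long)  # shape [vocab_size] -> error
y = torch.tensor([word2idx[w3]], dtype=torch.long)       # shape [1] == batch size
loss = loss_fn(yhat, y)
```

This also skips building the one-hot target entirely, since the index is already available in `word2idx`.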
I think it worked, the error is gone! Thank you so much!