Hi everyone,
After two hours of debugging, I still can’t find the reason for the error I’m getting:
ValueError: Expected target size (128, 10000), got torch.Size([128, 1])
I thought the code was pretty straightforward, and it even resembles one of the tutorials, but I still get this error.
The code is:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import matplotlib.pyplot as plt
import nltk

class FeedForwardLM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size):
        super(FeedForwardLM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.fc1 = nn.Linear(2 * embedding_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, vocab_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, w1, w2):
        mw1 = self.embedding(w1)
        mw2 = self.embedding(w2)
        m = torch.cat([mw1, mw2], dim=2)
        out = torch.tanh(self.fc1(m))
        out = self.fc2(out)
        out = self.softmax(out)
        return out

dataset = IndexLMDataset(train)
data_loader = data.DataLoader(dataset, batch_size=128, shuffle=True)

losses = []
loss_fn = nn.NLLLoss()
model = FeedForwardLM(10000, 300, 512).to(device)
optimizer = optim.Adam(model.parameters())

for _ in range(100):
    total_loss = 0
    for batch in data_loader:
        (w1, w2), y = batch
        model.zero_grad()
        yhat = model(w1, w2)
        loss = loss_fn(yhat, y)
        loss.backward()
        optimizer.step()
        total_loss += loss
    losses.append(total_loss)

torch.save(model.state_dict(), "feedforward.model")
plt.plot(losses);
and the code for the dataset is:
class IndexLMDataset(data.Dataset):
    def __init__(self, train):
        self._trigrams = list(nltk.trigrams(train))[:1000]

    def __len__(self):
        return len(self._trigrams)

    def __getitem__(self, index):
        w1, w2, w3 = self._trigrams[index]
        w1 = torch.tensor([word2idx.get(w1, word2idx['UNK'])], dtype=torch.long).to(device)
        w2 = torch.tensor([word2idx.get(w2, word2idx['UNK'])], dtype=torch.long).to(device)
        w3 = torch.tensor([word2idx.get(w3, word2idx['UNK'])], dtype=torch.long).to(device)
        return (w1, w2), w3
where train is just a list of words, and nltk.trigrams(train) returns triplets of words, so nltk.trigrams(['The', 'quick', 'brown', 'fox']) returns [('The', 'quick', 'brown'), ('quick', 'brown', 'fox')].
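In case it helps, this is the quick sanity check I was about to run next to see which shapes actually reach the loss (just a sketch, reusing the model, dataset and data_loader defined above):

# Pull one batch and print the shapes that go into NLLLoss.
(w1, w2), y = next(iter(data_loader))
print(w1.shape, w2.shape)   # shapes of the two context-word batches
with torch.no_grad():
    yhat = model(w1, w2)
print(yhat.shape, y.shape)  # what NLLLoss receives as input and target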
Any help is appreciated, as always!