recently i started learning PyTorch, so i have a model that takes an input and guess the next word
class LSTM(nn.Module):
def __init__(self, hidden_size, embedding_size, vocab_size):
super(LSTM, self).__init__()
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.embedding = nn.Embedding(vocab_size, embedding_size)
self.lstm = nn.LSTM(embedding_size, hidden_size)
self.out = nn.Linear(hidden_size, vocab_size)
self.dropout = nn.Dropout(0.1)
self.softmax = nn.LogSoftmax(dim=1)
def forward(self, sentence):
embeds = self.embedding(sentence)
lstm_out, _ = self.lstm(embeds)
output = self.out(output)
output = self.dropout(output)
out = self.softmax(output)
return out
batch_size = 20
input_size, output size = 32
vocab_size = 9882 #generated dynamically in original code
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=True)
loss_function = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.1)
def train(epoch):
model.train()
current_loss = 0
for x_batch, y_batch in dataloader:
model.zero_grad()
pred = model(x_batch)
loss = loss_function(pred, y_batch)
current_loss += loss
loss.backward()
optimizer.step()
stacktrace
ValueError Traceback (most recent call last)
<ipython-input-544-31519a92b1e7> in <module>()
2
3 for iter in range(1, n_iters + 1):
----> 4 output, loss = train(iter)
5 current_loss += loss
6 print_loss_total += loss
3 frames
<ipython-input-541-a72868643ddf> in train(epoch)
15 pred = model(x_batch)
16
---> 17 loss = loss_function(pred, y_batch)
18 print(loss, loss, 'loss')
19 current_loss += loss
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
545 result = self._slow_forward(*input, **kwargs)
546 else:
--> 547 result = self.forward(*input, **kwargs)
548 for hook in self._forward_hooks.values():
549 hook_result = hook(self, input, result)
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py in forward(self, input, target)
202
203 def forward(self, input, target):
--> 204 return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)
205
206
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
1832 if target.size()[1:] != input.size()[2:]:
1833 raise ValueError('Expected target size {}, got {}'.format(
-> 1834 out_size, target.size()))
1835 input = input.contiguous().view(n, c, 1, -1)
1836 target = target.contiguous().view(n, 1, -1)
ValueError: Expected target size (20, 9882), got torch.Size([20, 32])
I have a gut feeling this might be related to topk but i really have no clue what could be wrong.
Thanks and any help will be appreciated.