Hi, I am using PyTorch 1.10 and encountered this issue. Here is my code:
class DecoderRNN(nn.Module):
def __init__(self, hidden_size, output_size, embed_dim, bs):
super(DecoderRNN, self).__init__()
self.hidden_size = hidden_size
self.bs = bs
self.embed_dim = embed_dim
self.embedding = nn.Embedding(output_size, embed_dim)
self.gru = nn.GRU(embed_dim, hidden_size)
self.linear = nn.Linear(2 * hidden_size, output_size)
self.softmax_0 = nn.LogSoftmax(dim=0)
self.softmax_1 = nn.LogSoftmax(dim=1)
self.hidden = None
def forward(self, input, hidden, encoder_outputs):
output = self.embedding(input) # 1 * bs * ed
output = F.relu(output) # 1 * bs * ed
# output: 1 * bs * hs
# hidden: 1 * bs * hs
output, hidden = self.gru(output, hidden)
weights = self.dot_product(output, encoder_outputs)
weights = self.softmax_0(weights)
attn_output = self.weighted_sum(weights, encoder_outputs) # bs * ed
output = torch.cat((output[0], attn_output), dim=1)
output = self.linear(output) # bs * os
output = self.softmax_1(output) # bs * os
return output, hidden
def weighted_sum(self, weights, encoder_outputs):
outputs = encoder_outputs.clone()
res = torch.tensor([], device=device)
for i in range(weights.size()[0]): # n
weight = weights[i] # bs
words = outputs[i] # bs * ed
for j in range(weight.size()[0]): # bs
words[j] = words[j] * weight[j]
res = torch.cat((res, words.unsqueeze(0)), dim=0) # n * bs * ed
return torch.sum(res, dim=0) # bs * ed
def dot_product(self, a, b):
# a: 1 * bs * hs
# b: n * bs * hs
res = torch.tensor([], device=device)
decoder_word = a.squeeze(0)
for i in range(b.size()[0]):
temp = torch.tensor([], device=device)
for j in range(b.size()[1]):
dot = torch.matmul(decoder_word[j], b[i][j]).unsqueeze(0) # HERE
# dot = torch.randn(1, device=device)
temp = torch.cat((temp, dot), dim=0)
res = torch.cat((res, temp.unsqueeze(0)), dim=0) # n * bs
return res
def initHidden(self):
return torch.zeros(1, self.bs, self.hidden_size, device=device)
I found that if I replace “dot = torch.matmul(decoder_word[j], b[i][j]).unsqueeze(0)” by “dot = torch.randn(1, device=device)”, then no errors will be reported. Not sure if this information is helpful. Could anyone give me some suggestions? Thanks!