I am using pytorch on windows and am getting an error with a simple model using LSTM. The block of code below works
batch = 4
seq_length = 50
feature_dim = 28
hidden_size = 20
num_layers = 1
i = Variable(torch.randn(batch, seq_length, feature_dim))
rnn = nn.LSTM(feature_dim, hidden_size, num_layers, batch_first=True)
h0 = Variable(torch.randn(num_layers, batch, hidden_size))
c0 = Variable(torch.randn(num_layers, batch, hidden_size))
output, hn = rnn(i, (h0, c0))
But not this when i put the lstm in a class
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
batch_size = 4
feature_dim = 50
hidden_size = 20
num_layers = 1
self.lstm = nn.LSTM(feature_dim, hidden_size=hidden_size, num_layers=1, batch_first=True)
self.h0 = Variable(torch.randn(1, batch_size, hidden_size))
self.c0 = Variable(torch.randn(1, batch_size, hidden_size))
def forward(self, x, mode=False):
output, hn = self.lstm(x, (self.h0,self.c0))
return output
i = Variable(torch.randn(batch, seq_length, feature_dim))
model = Net()
output = model(i)
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
<ipython-input-8-56028679983b> in <module>()
1 model = Net()
----> 2 output = model(tr_input.narrow(0, 0, 4), True)
~\Anaconda3\envs\dl\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
355 result = self._slow_forward(*input, **kwargs)
356 else:
--> 357 result = self.forward(*input, **kwargs)
358 for hook in self._forward_hooks.values():
359 hook_result = hook(self, input, result)
<ipython-input-7-a19079c8e4fc> in forward(self, x, mode)
14 def forward(self, x, mode=False):
15
---> 16 output, hn = self.lstm(x, (self.h0,self.c0))
17 print(np.shape(output[-1]))
18
~\Anaconda3\envs\dl\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
355 result = self._slow_forward(*input, **kwargs)
356 else:
--> 357 result = self.forward(*input, **kwargs)
358 for hook in self._forward_hooks.values():
359 hook_result = hook(self, input, result)
~\Anaconda3\envs\dl\lib\site-packages\torch\nn\modules\rnn.py in forward(self, input, hx)
188 flat_weight = None
189
--> 190 self.check_forward_args(input, hx, batch_sizes)
191 func = self._backend.RNN(
192 self.mode,
~\Anaconda3\envs\dl\lib\site-packages\torch\nn\modules\rnn.py in check_forward_args(self, input, hidden, batch_sizes)
139 raise RuntimeError(
140 'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
--> 141 fn.input_size, input.size(-1)))
142
143 if is_input_packed:
NameError: name 'fn' is not defined