I have a fairly simple network:
class Network(nn.Module):
    """Stacked LSTM followed by a one-unit LSTM that emits one value per time step.

    Both recurrent modules are created with ``batch_first=True``, so plain
    tensor inputs are laid out as ``(batch, seq_len, input_dim)``.

    Args:
        input_dim: Number of features per time step of the input.
        lstm_layers: Number of layers in the stacked LSTM.
        hidden_size: Hidden size of every layer in the stacked LSTM.
        dropout: Dropout applied between the layers of the stacked LSTM.
        bidirectional: Whether the stacked LSTM runs in both directions.
    """

    def __init__(self, input_dim, lstm_layers=2, hidden_size=20, dropout=0, bidirectional=False):
        super(Network, self).__init__()
        self.lstm_layers = nn.LSTM(
            input_dim,
            hidden_size,
            num_layers=lstm_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional,
        )
        # A bidirectional LSTM concatenates the forward and backward hidden
        # states, so the features it emits are 2 * hidden_size wide. The
        # output layer must be sized accordingly or the forward pass fails
        # with an input-size mismatch.
        num_directions = 2 if bidirectional else 1
        self.output_lstm_layer = nn.LSTM(
            hidden_size * num_directions, 1, num_layers=1, batch_first=True
        )

    def forward(self, x):
        """Run ``x`` through both LSTMs and return only the output sequence.

        The final hidden and cell states of both LSTMs are discarded.
        """
        x, _ = self.lstm_layers(x)
        x, _ = self.output_lstm_layer(x)
        return x
However, when I try to train it:
# Training loop: runs one forward/backward/step cycle per batch yielded by
# train_dataloader.
# NOTE(review): `train_dataloader`, `net`, `loss`, `optim` and `optimizer`
# are defined outside this snippet.
for i, (data, labels) in enumerate(train_dataloader):
    # Move the batch to the GPU.
    data = data.cuda()
    # Flag the first element of `data` as requiring gradients, in place.
    # NOTE(review): presumably data[0] is the PackedSequence's underlying
    # tensor; inputs normally do not need gradients to train the network's
    # parameters - confirm this call is needed at all.
    data[0].requires_grad_()
    labels = labels.cuda()
    # NOTE(review): this makes the target require gradients as well; loss
    # targets are normally left with requires_grad=False - confirm.
    labels[0].requires_grad_()
    # NOTE(review): gradients are cleared through `optim`, but the update
    # below goes through `optimizer`; confirm both names refer to the same
    # optimizer object.
    optim.zero_grad()
    # NOTE(review): Network.forward returns a single object, so this two-way
    # unpack presumably splits a returned PackedSequence into its fields -
    # confirm against the installed PyTorch version.
    out, lengths = net(data)
    print(out)
    # Flatten the prediction to 1-D and compare it with the first element of
    # `labels`.
    l = loss(out.view(-1), labels[0])
    print(l)
    l.backward() # <- Error Here: RuntimeError: CUDNN_STATUS_BAD_PARAM
    optimizer.step()
    print(l)
`data` and `labels` out of `train_dataloader` are both `PackedSequence`s. They are initialized with `requires_grad=False`, which is why I call the `requires_grad_()` method on them. (On reflection, I'm not sure that this is needed.)
Unfortunately, while a loss is computed, `l.backward()` fails:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-55-c606886d4e12> in <module>()
11 l = loss(out.data.view(-1), labels.data)
12 print(l)
---> 13 l.backward()
14 optimizer.step()
15 print(l)
~/.virtualenvs/truu_walking_not_walking_lstm-JqLRVldl/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
91 products. Defaults to ``False``.
92 """
---> 93 torch.autograd.backward(self, gradient, retain_graph, create_graph)
94
95 def register_hook(self, hook):
~/.virtualenvs/truu_walking_not_walking_lstm-JqLRVldl/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
87 Variable._execution_engine.run_backward(
88 tensors, grad_tensors, retain_graph, create_graph,
---> 89 allow_unreachable=True) # allow_unreachable flag
90
91
RuntimeError: CUDNN_STATUS_BAD_PARAM
I have no idea why this happens.
Does anyone have an idea?