When I use torch.nn.utils.stateless.functional_call with an nn.Linear module and compute gradients with respect to the parameters passed to functional_call, everything works fine.
But when I do the same with an nn.LSTM module and then compute gradients, the computation graph breaks and I get a runtime error.
Here is a minimal reproducible example:
import torch
import torch.nn as nn
from torch import autograd

def linear():
    linear = nn.Linear(2, 3)
    inp = torch.randn((2, 2))
    # Fresh tensors to substitute for the module's own parameters.
    params = {}
    for k, v in linear.named_parameters():
        params[k] = torch.randn_like(v, requires_grad=True)
    output = nn.utils.stateless.functional_call(linear, params, inp)
    loss = output.sum()
    delta = autograd.grad(loss, [params[k] for k in params],
                          torch.ones_like(loss), create_graph=True)
    print('linear successful')

def lstm():
    lstm = nn.LSTM(3, 2, batch_first=True)
    inp = torch.randn((1, 5, 3))
    params = {}
    for k, v in lstm.named_parameters():
        params[k] = torch.randn_like(v, requires_grad=True)
    output, _ = nn.utils.stateless.functional_call(lstm, params, inp)
    loss = output.sum()
    delta = autograd.grad(loss, [params[k] for k in params],
                          torch.ones_like(loss), create_graph=True)
    print('lstm successful')

linear()
lstm()
Output:
linear successful
Traceback (most recent call last):
File "/home/tathagat/lm/ehr-meta-learning/rough.py", line 28, in <module>
lstm()
File "/home/tathagat/lm/ehr-meta-learning/rough.py", line 24, in lstm
delta = autograd.grad(loss, [params[k] for k in params], torch.ones_like(loss), create_graph=True)
File "/home/tathagat/anaconda3/envs/ehr/lib/python3.9/site-packages/torch/autograd/__init__.py", line 276, in grad
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
To debug further, I modified the lstm() function as follows:
def lstm():
    lstm = nn.LSTM(3, 2, batch_first=True)
    inp = torch.randn((1, 5, 3))
    params = {}
    for k, v in lstm.named_parameters():
        params[k] = torch.randn_like(v, requires_grad=True)
    output, _ = nn.utils.stateless.functional_call(lstm, params, inp)
    loss = output.sum()
    # delta = autograd.grad(loss, [params[k] for k in params],
    #                       torch.ones_like(loss), create_graph=True)
    delta = autograd.grad(loss, [params[k] for k in params],
                          torch.ones_like(loss), create_graph=True,
                          allow_unused=True)
    # Collect every parameter whose gradient came back as None.
    none_grad_params = []
    for d, k in zip(delta, list(params.keys())):
        if d is None:
            none_grad_params.append(k)
    # This assertion passes: *all* parameters are unused in the graph.
    assert len(none_grad_params) == len(list(lstm.named_parameters()))
    print('lstm successful')
Now, on calling lstm(), I get the output lstm successful, which indicates that the gradients for all parameters were None.
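My current guess, from reading torch/nn/modules/rnn.py, is that nn.LSTM's forward never reads the re-registered named parameters at all: it passes self._flat_weights, a plain Python list of tensors captured at construction time, to the backend kernel, so the tensors substituted by functional_call never enter the graph. Here is a small check of that hypothesis (lstm_diagnostic is my own helper name; if the guess is right, gradients should flow to the module's original parameters instead):

def lstm_diagnostic():
    lstm = nn.LSTM(3, 2, batch_first=True)
    inp = torch.randn((1, 5, 3))
    params = {k: torch.randn_like(v, requires_grad=True)
              for k, v in lstm.named_parameters()}
    output, _ = nn.utils.stateless.functional_call(lstm, params, inp)
    loss = output.sum()
    # Differentiate w.r.t. the module's *own* parameters rather than the
    # substituted ones. If forward really used the cached _flat_weights,
    # these should all come back non-None.
    grads = autograd.grad(loss, list(lstm.parameters()), allow_unused=True)
    print([g is not None for g in grads])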
@albanD @ptrblck Can you please help with this? I read your posts mentioning that torch.nn.utils.stateless.functional_call should work in these cases, but it doesn't work here with LSTMs.
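In case it helps narrow the problem down, here is a possible interim workaround (a sketch only, not validated beyond this toy example): skip the module's cached weight list entirely and call the backend entry point that nn.LSTM.forward dispatches to, passing the substituted tensors directly. Note that torch._VF.lstm is a private API whose signature I copied from torch/nn/modules/rnn.py, so it may change between releases:

def lstm_workaround():
    lstm = nn.LSTM(3, 2, batch_first=True)
    inp = torch.randn((1, 5, 3))
    params = {k: torch.randn_like(v, requires_grad=True)
              for k, v in lstm.named_parameters()}
    # Zero initial hidden/cell states, shaped (num_layers, batch, hidden_size).
    h0 = inp.new_zeros((lstm.num_layers, inp.size(0), lstm.hidden_size))
    c0 = inp.new_zeros((lstm.num_layers, inp.size(0), lstm.hidden_size))
    # For a single-layer, unidirectional LSTM, named_parameters() yields
    # weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0 -- the order the
    # kernel expects for its flat weight list.
    flat_weights = [params[k] for k, _ in lstm.named_parameters()]
    # Private API: signature taken from nn.LSTM.forward in torch/nn/modules/rnn.py.
    output, h_n, c_n = torch._VF.lstm(
        inp, (h0, c0), flat_weights, lstm.bias, lstm.num_layers,
        lstm.dropout, lstm.training, lstm.bidirectional, lstm.batch_first)
    loss = output.sum()
    delta = autograd.grad(loss, [params[k] for k in params],
                          torch.ones_like(loss), create_graph=True)
    print('lstm workaround successful')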