Computation graph breaks when using torch.nn.utils.stateless.functional_call with nn.LSTM

When I use torch.nn.utils.stateless.functional_call with an nn.Linear module and compute gradients with respect to the parameters passed to functional_call, everything works as expected.
But when I do the same with an nn.LSTM module, the computation graph appears to break and autograd.grad raises a RuntimeError.
Here is a minimal reproducible example:

import torch
import torch.nn as nn
from torch import autograd

def linear():
    linear = nn.Linear(2,3)
    inp = torch.randn((2,2))
    params = {}
    for k, v in linear.named_parameters():
        params[k] = torch.randn_like(v, requires_grad=True)
    output = nn.utils.stateless.functional_call(linear, params, inp)
    loss = output.sum()
    delta = autograd.grad(loss, [params[k] for k in params], torch.ones_like(loss), create_graph=True)
    print('linear successful')

def lstm():
    lstm = nn.LSTM(3, 2, batch_first=True)
    inp = torch.randn((1, 5, 3))
    params = {}
    for k, v in lstm.named_parameters():
        params[k] = torch.randn_like(v, requires_grad=True)
    output, _ = nn.utils.stateless.functional_call(lstm, params, inp)
    loss = output.sum()
    delta = autograd.grad(loss, [params[k] for k in params], torch.ones_like(loss), create_graph=True)
    print('lstm successful')

linear()
lstm()

Output:

linear successful
Traceback (most recent call last):
  File "/home/tathagat/lm/ehr-meta-learning/rough.py", line 28, in <module>
    lstm()
  File "/home/tathagat/lm/ehr-meta-learning/rough.py", line 24, in lstm
    delta = autograd.grad(loss, [params[k] for k in params], torch.ones_like(loss), create_graph=True)
  File "/home/tathagat/anaconda3/envs/ehr/lib/python3.9/site-packages/torch/autograd/__init__.py", line 276, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

To debug further, I modified the lstm() function as follows:

def lstm():
    lstm = nn.LSTM(3, 2, batch_first=True)
    inp = torch.randn((1, 5, 3))
    params = {}
    for k, v in lstm.named_parameters():
        params[k] = torch.randn_like(v, requires_grad=True)
    output, _ = nn.utils.stateless.functional_call(lstm, params, inp)
    loss = output.sum()
    # delta = autograd.grad(loss, [params[k] for k in params], torch.ones_like(loss), create_graph=True)
    delta = autograd.grad(loss, [params[k] for k in params], torch.ones_like(loss), create_graph=True, allow_unused=True)
    none_grad_params = []
    for d, k in zip(delta, list(params.keys())):
        if d is None:
            none_grad_params.append(k)
    assert len(none_grad_params) == len(list(lstm.named_parameters()))
    print('lstm successful')

Now calling lstm() prints lstm successful, i.e. the assertion passes: the gradient returned for every one of the passed-in parameter tensors is None, so none of them were used in the computation graph.
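
As a further check (sketched below, not yet verified), I want to ask for gradients with respect to both the passed-in tensors and the module's own parameters. If nn.LSTM is reading from its own (cached/flattened) weights instead of the substituted ones, I would expect the gradients to land on lstm.parameters() rather than on params:

def lstm_grad_check():
    lstm = nn.LSTM(3, 2, batch_first=True)
    inp = torch.randn((1, 5, 3))
    params = {k: torch.randn_like(v, requires_grad=True) for k, v in lstm.named_parameters()}
    output, _ = nn.utils.stateless.functional_call(lstm, params, inp)
    loss = output.sum()
    # Request gradients w.r.t. both the substituted tensors and the module's
    # own parameters; allow_unused=True returns None for any tensor that did
    # not participate in the graph.
    grads = autograd.grad(
        loss,
        list(params.values()) + list(lstm.parameters()),
        allow_unused=True,
    )
    n = len(params)
    print('passed-in tensors used:', [g is not None for g in grads[:n]])
    print('module parameters used:', [g is not None for g in grads[n:]])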

@albanD @ptrblck Can you please help with this? I read your posts mentioning that torch.nn.utils.stateless.functional_call should work in these cases, but it does not seem to work here with nn.LSTM.
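
In the meantime, would manually unrolling the sequence with nn.LSTMCell be a reasonable workaround? As far as I can tell, LSTMCell reads its weights by attribute on every step rather than from a flattened weight list, so functional_call should be able to substitute them. A rough, unverified sketch (single layer, unidirectional, hidden size 2 to match the example above):

def lstm_cell_workaround():
    cell = nn.LSTMCell(3, 2)
    inp = torch.randn((1, 5, 3))
    params = {k: torch.randn_like(v, requires_grad=True) for k, v in cell.named_parameters()}
    # Hidden and cell state for a batch of 1 and hidden size 2.
    h = torch.zeros(1, 2)
    c = torch.zeros(1, 2)
    outputs = []
    for t in range(inp.size(1)):
        # Substitute the parameters on every step of the manual unroll.
        h, c = nn.utils.stateless.functional_call(cell, params, (inp[:, t], (h, c)))
        outputs.append(h)
    loss = torch.stack(outputs, dim=1).sum()
    delta = autograd.grad(loss, [params[k] for k in params], create_graph=True)
    print('lstm cell workaround successful')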