How to properly set up parameters when using functional_call and nn.LSTM

I have the following model:

import math

import torch
import torch.nn as nn
from torch.distributions import Normal

epsilon = 1e-6  # placeholder value; in my code this is a small constant defined elsewhere

class Encoder(nn.Module):
    def __init__(self, action_dim, z_dim, skill_length):
        super().__init__()

        self.lstm = nn.LSTM(action_dim, z_dim, skill_length, batch_first=True)
        self.log_std = nn.Parameter(torch.Tensor(z_dim))

    def forward(self, skill):
        mean, _ = self.lstm(skill)
        mean = mean[:, -1, :]  # keep only the last time step
        std = torch.exp(torch.clamp(self.log_std, min=math.log(epsilon)))
        density = Normal(mean, std)
        sample = density.rsample()

        return sample, density

Then, I have the following code to extract and initialize parameters:

def pars(model):
    params = {}
    for name, param in model.named_parameters():
        if 'std' in name or param.dim() < 2:
            # constant init for log_std and the 1-D bias vectors
            # (orthogonal_ only supports tensors with at least 2 dimensions)
            init = torch.nn.init.constant_(param, 0)
        else:
            init = torch.nn.init.orthogonal_(param)
        params[name] = nn.Parameter(init)
    return params

Then, I’d like to do this:

from torch import autograd
from torch.func import functional_call

model = Encoder(6, 2, 10)
x = torch.rand(25, 10, 6)
params = pars(model)

samp, d = functional_call(model, params, x)

grad = autograd.grad(torch.mean(samp), params.values(), retain_graph=True, allow_unused=True)

There should be gradients with respect to the LSTM parameters as well as the std parameter, but in my case I get None for all of the LSTM parameters.
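To make this concrete, here is roughly how I inspect the result, pairing the grad tuple with the parameter names:

for name, g in zip(params.keys(), grad):
    print(name, None if g is None else tuple(g.shape))
# every lstm.* entry prints None; only log_std comes back with a gradient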

When I run

samp, d = model(x)  # plain forward pass, using the module's own parameters
grad = autograd.grad(torch.mean(samp), model.parameters(), retain_graph=True)

then I get the expected gradients, i.e., non-None gradients with respect to the LSTM parameters as well as the std parameter.

I am not sure what the problem is. I have successfully computed gradients with functional_call before, but it seems that introducing rsample is causing some issues. I'd appreciate any help.

Edit: I found that the problem is with the LSTM layer. When I use a linear layer instead of the LSTM, everything works fine (see the sketch below).
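For reference, this is roughly the linear variant I mean; it is only a minimal sketch (the LinearEncoder name and the per-step Linear are just how I set up the test), but with it autograd.grad returns non-None gradients for every parameter:

class LinearEncoder(nn.Module):
    def __init__(self, action_dim, z_dim):
        super().__init__()
        # same structure as Encoder, but a per-step Linear instead of the LSTM
        self.lin = nn.Linear(action_dim, z_dim)
        self.log_std = nn.Parameter(torch.Tensor(z_dim))

    def forward(self, skill):
        mean = self.lin(skill)[:, -1, :]  # apply per time step, keep the last step
        std = torch.exp(torch.clamp(self.log_std, min=math.log(epsilon)))
        density = Normal(mean, std)
        sample = density.rsample()
        return sample, density

model = LinearEncoder(6, 2)
x = torch.rand(25, 10, 6)
params = pars(model)

samp, d = functional_call(model, params, x)
grad = autograd.grad(torch.mean(samp), params.values(), retain_graph=True, allow_unused=True)
# here no entry of grad is None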