Question about using vmap to calculate per-sample gradients

Hi there,
I run into a problem when I use vmap to compute a batch of per-sample gradients for an RNN model. It always raises a dimension mismatch error and a "batching rule not implemented" error. Can anyone help me solve this?
Here is my code:

import torch
import torch.nn as nn
from torch.func import functional_call, vmap, grad

class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # batch_first=True means input/output is (batch, sequence, features)
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x: (sequence_length, input_size) for a single sequence
        # x: (batch_size, sequence_length, input_size) for a batch
        h0 = torch.zeros(1, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.rnn(x, h0)
        output = self.fc(out[:, -1, :])
        return output

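def compute_loss_for_single_sequence(model, single_input_sequence):
    # Add a batch dimension of size 1: (sequence_length, input_size) -> (1, sequence_length, input_size)
    input_with_batch = single_input_sequence.unsqueeze(0)
    # No target is used here, so the "loss" is simply the model output of shape (1, output_size)
    loss = functional_call(model, dict(model.named_parameters()), input_with_batch)
    return loss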

if __name__ == '__main__':
    # 1. Prepare batched data and model
    input_size = 10
    hidden_size = 20
    output_size = 1
    batch_size = 1000
    sequence_length = 5
    feature_dim = input_size
    model = SimpleRNN(input_size, hidden_size, output_size)
    params = {name: p for name, p in model.named_parameters()}
    # Input data: (batch_size, sequence_length, feature_dim)
    batch_input = torch.randn(batch_size, sequence_length, feature_dim, requires_grad=True)

    # 2. Test the batched forward evaluation
    batch_output = vmap(compute_loss_for_single_sequence, in_dims=(None, 0))(model, batch_input)

    # 3. Test the batched per-sample gradients
    per_example_grads = vmap(grad(compute_loss_for_single_sequence), in_dims=(None, 0))(model, batch_input)

Here are the errors:
  1. When I test the batched forward evaluation, it raises:
    RuntimeError: Batching rule not implemented for aten::rnn_tanh.input. We could not generate a fallback.
  2. When I test the batched per-sample gradients, it raises the error below. Since I cannot get vmap to evaluate the function "compute_loss_for_single_sequence" correctly in the first place, it is no surprise that the per-sample gradient test fails as well:
    ValueError: Thing passed to transform API must be Tensor, got <class '__main__.SimpleRNN'>
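
For anyone reproducing this, here is a minimal sketch of one possible workaround, not a definitive fix. It rests on two assumptions. First, grad differentiates with respect to its first argument, so that argument has to be a tensor or a pytree of tensors (here a dict of parameters) rather than the module itself. Second, since the error says aten::rnn_tanh.input has no batching rule, the recurrence is unrolled by hand using only tanh and linear ops. The function name scalar_output_for_single_sequence and the use of the summed output as a stand-in scalar "loss" are hypothetical choices made for illustration, because the original code has no target.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import grad, vmap

torch.manual_seed(0)
input_size, hidden_size, output_size = 10, 20, 1
batch_size, sequence_length = 8, 5

# Same layers as in SimpleRNN, created standalone to keep the sketch short.
rnn = nn.RNN(input_size, hidden_size, batch_first=True)
fc = nn.Linear(hidden_size, output_size)

# Detached parameters in a dict, keyed the way model.named_parameters() would name them.
params = {f"rnn.{name}": p.detach() for name, p in rnn.named_parameters()}
params.update({f"fc.{name}": p.detach() for name, p in fc.named_parameters()})


def scalar_output_for_single_sequence(params, x):
    """Hypothetical helper. x has shape (sequence_length, input_size)."""
    h = torch.zeros(params["rnn.weight_hh_l0"].shape[0], dtype=x.dtype, device=x.device)
    # Hand-unrolled tanh RNN: h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)
    for t in range(x.shape[0]):
        h = torch.tanh(
            F.linear(x[t], params["rnn.weight_ih_l0"], params["rnn.bias_ih_l0"])
            + F.linear(h, params["rnn.weight_hh_l0"], params["rnn.bias_hh_l0"])
        )
    out = F.linear(h, params["fc.weight"], params["fc.bias"])
    # grad needs a scalar; summing the output is a placeholder for a real loss.
    return out.sum()


batch_input = torch.randn(batch_size, sequence_length, input_size)

# Differentiate w.r.t. params (argument 0); map over the batch dimension of the input only.
per_sample_grads = vmap(grad(scalar_output_for_single_sequence), in_dims=(None, 0))(
    params, batch_input
)

for name, g in per_sample_grads.items():
    print(name, tuple(g.shape))  # each gradient gains a leading batch dimension
```

Each entry of per_sample_grads has the shape of the corresponding parameter with an extra leading batch dimension. With a real target, the loss function would take the target as a third argument and in_dims would become (None, 0, 0).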