Seeking a compatible library/package to calculate the second derivative with PyTorch on the GPU

I have a Python code segment from a deep RL algorithm that performs second-order optimization, computing second derivatives through the Hessian and the Fisher information matrix. Normally I run on GPU (CUDA), but I ran into a computational issue when calculating the second derivative on CUDA:

NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented. Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API. To run double backwards, please disable the CuDNN backend temporarily while running the forward pass of your RNN. For example: 
with torch.backends.cudnn.flags(enabled=False):
    output = model(inputs)

So I had to move to the CPU for this code segment:

grads = torch.autograd.grad(policy_loss, self.policy.Actor.parameters(), retain_graph=True)
loss_grad = torch.cat([grad.view(-1) for grad in grads])

def Fvp_fim(v=-loss_grad):
    # Fisher-vector product: J^T M J v / N (plus the std block and damping), computed via double backward
    with torch.backends.cudnn.flags(enabled=False):
        M, mu, info = self.policy.Actor.get_fim(states_batch)
        #pdb.set_trace()
        mu = mu.view(-1)
        filter_input_ids = set([info['std_id']])

        t = torch.ones(mu.size(), requires_grad=True, device=mu.device)
        mu_t = (mu * t).sum()
        Jt = compute_flat_grad(mu_t, self.policy.Actor.parameters(), filter_input_ids=filter_input_ids, create_graph=True)
        Jtv = (Jt * v).sum()
        Jv = torch.autograd.grad(Jtv, t)[0]
        MJv = M * Jv.detach()
        mu_MJv = (MJv * mu).sum()
        JTMJv = compute_flat_grad(mu_MJv, self.policy.Actor.parameters(), filter_input_ids=filter_input_ids, create_graph=True).detach()
        JTMJv /= states_batch.shape[0]
        std_index = info['std_index']
        JTMJv[std_index: std_index + M.shape[0]] += 2 * v[std_index: std_index + M.shape[0]]
        return JTMJv + v * self.damping

Above is the main function, where the second derivative is calculated. Below are the supporting functions and relevant classes it uses.

def compute_flat_grad(output, inputs, filter_input_ids=set(), retain_graph=True, create_graph=False):
    # Flatten the gradients of `output` w.r.t. `inputs` into a single 1-D vector;
    # parameters whose index is in `filter_input_ids`, or that receive no gradient, contribute zeros.
    if create_graph:
        retain_graph = True

    inputs = list(inputs)
    params = []
    for i, param in enumerate(inputs):
        if i not in filter_input_ids:
            params.append(param)

    grads = torch.autograd.grad(output, params, retain_graph=retain_graph, create_graph=create_graph, allow_unused=True)

    j = 0
    out_grads = []
    for i, param in enumerate(inputs):
        if i in filter_input_ids:
            out_grads.append(torch.zeros(param.view(-1).shape, device=param.device, dtype=param.dtype))
        else:
            if grads[j] is None:
                out_grads.append(torch.zeros(param.view(-1).shape, device=param.device, dtype=param.dtype))
            else:
                out_grads.append(grads[j].view(-1))
            j += 1
    grads = torch.cat(out_grads)

    for param in params:
        param.grad = None
    return grads
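As a quick illustration of what compute_flat_grad returns (the toy model and sizes below are made up purely for the example, not part of my code):

toy = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 1))
out = toy(torch.randn(5, 3)).sum()

# Filter out parameter index 0 (the first weight matrix): its slot is zero-filled.
flat = compute_flat_grad(out, toy.parameters(), filter_input_ids={0}, create_graph=True)
print(flat.shape)  # torch.Size([21]) = 12 (zeros) + 4 + 4 + 1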

------

self.policy.Actor.get_fim is defined as follows:

import torch
import torch.nn as nn


from agents.models.feature_extracter import LSTMFeatureExtractor
from agents.models.policy import PolicyModule
from agents.models.value import ValueModule


class ActorNetwork(nn.Module):
    def __init__(self, args):
        super(ActorNetwork, self).__init__()
        self.FeatureExtractor = LSTMFeatureExtractor(args)
        self.PolicyModule = PolicyModule(args)

    def forward(self, s):
        lstmOut = self.FeatureExtractor.forward(s)
        mu, sigma, action, log_prob = self.PolicyModule.forward(lstmOut)
        return mu, sigma, action, log_prob
    
    def get_fim(self, x):
        # Return the per-dimension inverse covariance (the diagonal Fisher block for the mean),
        # the mean actions, and the parameter-list index / flat offset of the std parameter.
        mu, sigma, _, _ = self.forward(x)

        if sigma.dim() == 1:
            sigma = sigma.unsqueeze(0)

        cov_inv = sigma.pow(-2).repeat(x.size(0), 1)

        param_count = 0
        std_index = 0
        id = 0
        std_id = id
        for name, param in self.named_parameters():
            if name == "sigma.weight":
                std_id = id
                std_index = param_count
            param_count += param.view(-1).shape[0]
            id += 1

        return cov_inv.detach(), mu, {'std_id': std_id, 'std_index': std_index}
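
For reference, for a Gaussian N(mu, sigma^2) the Fisher information with respect to the mean is 1/sigma^2, which is exactly what cov_inv encodes (repeated per batch row); the Fisher block of a log-std parameter is the constant 2, which matches the 2 * v term added on the std slice in Fvp_fim (assuming a log-std parameterization). A tiny check with made-up numbers:

# Illustrative values only: Fisher information of the mean of N(mu, sigma^2) is 1 / sigma^2
sigma = torch.tensor([0.5, 2.0])
print(sigma.pow(-2))  # tensor([4.0000, 0.2500]) -- the same quantity get_fim returns as cov_inv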

The torch.backends.cudnn.flags(enabled=False) syntax avoids cuDNN in the relevant function, but I need this code to run continuously on the GPU instead of moving to the CPU for just this part. Is there any specific library or package, compatible with PyTorch, that I can use to calculate the second derivative computed in Fvp_fim?

This context manager disables cuDNN only for the block but still uses native CUDA kernels.
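A minimal stand-alone illustration of that behaviour (toy LSTM sizes; it assumes a CUDA device is available):

with torch.backends.cudnn.flags(enabled=False):
    lstm = nn.LSTM(8, 16).cuda()
    out, _ = lstm(torch.randn(5, 3, 8, device="cuda"))
    print(out.device)  # cuda:0 -- the RNN still runs with native CUDA kernels, only cuDNN is bypassed
    g = torch.autograd.grad(out.sum(), lstm.parameters(), create_graph=True)
    # differentiating through g again (double backward) now works; with cuDNN enabled
    # it raises the NotImplementedError quoted above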

In the bigger picture, there are large numbers of batches going through this function, and since all of them have to pass through it sequentially, the total running time increases substantially.
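
For completeness, the reason Fvp_fim gets called so many times per update is that it acts as the matrix-vector-product oracle of a conjugate-gradient solver (as in TRPO-style updates). The sketch below only illustrates that calling pattern; the helper name conjugate_gradient and the iteration count are assumptions, not my exact training loop:

def conjugate_gradient(fvp, b, n_iters=10, tol=1e-10):
    # Solve F x = b where F is only available through fvp(x) (here: Fvp_fim).
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)
    for _ in range(n_iters):
        Fp = fvp(p)                       # one Fisher-vector product per iteration
        alpha = rdotr / torch.dot(p, Fp)
        x += alpha * p
        r -= alpha * Fp
        new_rdotr = torch.dot(r, r)
        if new_rdotr < tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x

step_dir = conjugate_gradient(Fvp_fim, -loss_grad)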