I have a Python code segment from a deep RL algorithm that performs second-order optimization: it computes second derivatives via the Hessian and the Fisher information matrix. Normally I run everything on GPU (CUDA), but I ran into the following error when computing the second derivative there:
NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented. Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API. To run double backwards, please disable the CuDNN backend temporarily while running the forward pass of your RNN. For example:
    with torch.backends.cudnn.flags(enabled=False):
        output = model(inputs)
So I had to move this code segment to the CPU:
grads = torch.autograd.grad(policy_loss, self.policy.Actor.parameters(), retain_graph=True)
loss_grad = torch.cat([grad.view(-1) for grad in grads])

def Fvp_fim(v=-loss_grad):
    # Fisher-vector product: J^T M J v over the batch, plus the log-std block and damping
    with torch.backends.cudnn.flags(enabled=False):
        M, mu, info = self.policy.Actor.get_fim(states_batch)
        # pdb.set_trace()
        mu = mu.view(-1)
        filter_input_ids = set([info['std_id']])

        t = torch.ones(mu.size(), requires_grad=True, device=mu.device)
        mu_t = (mu * t).sum()
        Jt = compute_flat_grad(mu_t, self.policy.Actor.parameters(), filter_input_ids=filter_input_ids, create_graph=True)
        Jtv = (Jt * v).sum()
        Jv = torch.autograd.grad(Jtv, t)[0]
        MJv = M * Jv.detach()
        mu_MJv = (MJv * mu).sum()
        JTMJv = compute_flat_grad(mu_MJv, self.policy.Actor.parameters(), filter_input_ids=filter_input_ids, create_graph=True).detach()
        JTMJv /= states_batch.shape[0]
        std_index = info['std_index']
        JTMJv[std_index: std_index + M.shape[0]] += 2 * v[std_index: std_index + M.shape[0]]
        return JTMJv + v * self.damping
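For context, Fvp_fim is the matrix-vector product that gets plugged into a conjugate-gradient solve (the usual TRPO/natural-gradient step). A simplified sketch of that usage is below; the iteration count and tolerance are illustrative values, not my exact settings:

import torch

def conjugate_gradient(fvp, b, cg_iters=10, residual_tol=1e-10):
    # Solve F x = b using only Fisher-vector products fvp(v)
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)
    for _ in range(cg_iters):
        Fp = fvp(p)
        alpha = rdotr / torch.dot(p, Fp)
        x += alpha * p
        r -= alpha * Fp
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x

step_dir = conjugate_gradient(Fvp_fim, -loss_grad)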
Above is the main function, which computes the second derivative (the Fisher-vector product). Below are the supporting functions and the relevant class it uses.
def compute_flat_grad(output, inputs, filter_input_ids=set(), retain_graph=True, create_graph=False):
    # Gradient of `output` w.r.t. `inputs`, flattened into a single vector;
    # parameters whose index is in filter_input_ids get zeros instead.
    if create_graph:
        retain_graph = True

    inputs = list(inputs)
    params = []
    for i, param in enumerate(inputs):
        if i not in filter_input_ids:
            params.append(param)

    grads = torch.autograd.grad(output, params, retain_graph=retain_graph, create_graph=create_graph, allow_unused=True)

    j = 0
    out_grads = []
    for i, param in enumerate(inputs):
        if i in filter_input_ids:
            out_grads.append(torch.zeros(param.view(-1).shape, device=param.device, dtype=param.dtype))
        else:
            if grads[j] is None:
                # allow_unused=True returns None for parameters not reached by the graph
                out_grads.append(torch.zeros(param.view(-1).shape, device=param.device, dtype=param.dtype))
            else:
                out_grads.append(grads[j].view(-1))
            j += 1
    grads = torch.cat(out_grads)

    for param in params:
        param.grad = None
    return grads
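Just to illustrate the shape contract of compute_flat_grad, here is a toy check with a throwaway nn.Linear (not part of my actual code):

import torch
import torch.nn as nn

lin = nn.Linear(3, 2)
out = lin(torch.randn(5, 3)).sum()
flat = compute_flat_grad(out, lin.parameters(), create_graph=True)
# one flat vector covering every parameter (weight, then bias)
assert flat.numel() == sum(p.numel() for p in lin.parameters())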
------
self.policy.Actor.get_fim comes from the following class:
import torch
import torch.nn as nn

from agents.models.feature_extracter import LSTMFeatureExtractor
from agents.models.policy import PolicyModule
from agents.models.value import ValueModule


class ActorNetwork(nn.Module):
    def __init__(self, args):
        super(ActorNetwork, self).__init__()
        self.FeatureExtractor = LSTMFeatureExtractor(args)
        self.PolicyModule = PolicyModule(args)

    def forward(self, s):
        lstmOut = self.FeatureExtractor.forward(s)
        mu, sigma, action, log_prob = self.PolicyModule.forward(lstmOut)
        return mu, sigma, action, log_prob

    def get_fim(self, x):
        mu, sigma, _, _ = self.forward(x)

        if sigma.dim() == 1:
            sigma = sigma.unsqueeze(0)
        cov_inv = sigma.pow(-2).repeat(x.size(0), 1)

        param_count = 0
        std_index = 0
        id = 0
        std_id = id
        for name, param in self.named_parameters():
            if name == "sigma.weight":
                std_id = id
                std_index = param_count
            param_count += param.view(-1).shape[0]
            id += 1

        return cov_inv.detach(), mu, {'std_id': std_id, 'std_index': std_index}
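For reference, my understanding of what Fvp_fim(v) computes with these pieces is the Fisher-vector product of the Gaussian policy:

    F v ≈ (1/N) * J^T M J v  +  2 * v[std block]  +  damping * v

where J = dmu/dtheta is the Jacobian of the flattened means with respect to the policy parameters, M = diag(sigma^-2) is the cov_inv returned by get_fim, N = states_batch.shape[0], and the 2*v term is the Fisher block for the log-std parameters located at std_index.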
The torch.backends.cudnn.flags(enabled=False) context manager only disables the cuDNN backend inside the relevant function. Still, I need to run this code entirely on the GPU instead of moving to the CPU just for this part. Is there a specific library or package, compatible with PyTorch, that I can use to compute the second derivative calculated in Fvp_fim?