I am trying to compute input gradients of a network in such a way that the Jacobian of the network is effectively detached before it is multiplied with the `grad_outputs`. The code below shows what I'm trying to achieve, but it is horribly inefficient due to the loop. Here `instance_losses` is a vector that contains the loss per instance in the batch.
```python
params = self.get_params(model)

# dims: [batch]
instance_loss_grads = torch.autograd.grad(loss, instance_losses, create_graph=True)[0]

# EXPENSIVE OPERATION: compute the Jacobian w.r.t. the training instances.
# type: list (one entry per instance) of tuples (one entry per param group)
instance_param_grads = [
    torch.autograd.grad(instance_loss, params, retain_graph=True)
    for instance_loss in instance_losses
]

param_grads = []
for i in range(len(params)):  # for each param group
    # Create the Jacobian of this param group w.r.t. the training instances.
    # Note the `gs[i].detach()`: this is needed and is what makes the problem complex.
    group_jacobian = torch.stack([gs[i].detach() for gs in instance_param_grads], dim=0)  # dims: [batch, ...]
    # Multiply the differentiable `instance_loss_grads` by the constant Jacobian.
    param_grads.append(torch.einsum('i,ij...->j...', instance_loss_grads, group_jacobian))
```
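To make the intent of the final `einsum` concrete, here is a minimal shape-level sketch with stand-in tensors (the names and shapes are illustrative, not from the actual model): the batch dimension is contracted away, gradients flow only through the loss-grad vector, and the Jacobian acts as a constant.

```python
import torch

# Stand-in shapes to illustrate the contraction above.
batch, d1, d2 = 4, 3, 5
instance_loss_grads = torch.randn(batch, requires_grad=True)  # differentiable
group_jacobian = torch.randn(batch, d1, d2)                   # detached constant

# Contract the batch dimension; the result has the param group's shape.
param_grad = torch.einsum('i,ij...->j...', instance_loss_grads, group_jacobian)
assert param_grad.shape == (d1, d2)
# Autograd tracks only `instance_loss_grads`; the Jacobian contributes no gradient.
assert param_grad.requires_grad
```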
I see two potential options:
- Find a way to tell `torch.autograd.grad` that all operations during backprop must be performed after detaching.
- Find a way to compute the Jacobian efficiently. This has been discussed many times on the PyTorch forums with no general solution.
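For the second option, one possible direction (a sketch, not a drop-in replacement for the code above) is the `torch.func` per-sample-gradient recipe: `vmap(grad(...))` computes the whole Jacobian of per-instance losses w.r.t. the parameters in one vectorized call instead of a Python loop. The model, data, and loss below are hypothetical placeholders:

```python
import torch
from torch.func import functional_call, vmap, grad

# Hypothetical tiny model and batch, for illustration only.
model = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)
y = torch.randn(8, 1)

params = {k: v.detach() for k, v in model.named_parameters()}

def instance_loss(p, xi, yi):
    # Loss for a single instance; unsqueeze adds a batch dim of 1.
    out = functional_call(model, p, (xi.unsqueeze(0),))
    return torch.nn.functional.mse_loss(out, yi.unsqueeze(0))

# One vectorized call replaces the per-instance loop: map over the batch
# dims of x and y while sharing params. Each result tensor has a leading
# batch dimension, i.e. the stacked Jacobian blocks.
per_sample_grads = vmap(grad(instance_loss), in_dims=(None, 0, 0))(params, x, y)

# Detached Jacobian blocks, analogous to `group_jacobian` above.
jacobians = {k: g.detach() for k, g in per_sample_grads.items()}
# e.g. jacobians['weight'] has shape [batch, 1, 4]
```

Whether this composes with the outer `create_graph=True` gradient in your setting would need checking, but it removes the expensive loop.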
Any suggestions on how this can be optimized?