I am trying to compute input gradients of a network in such a way that the Jacobian of the network is effectively detached before it is multiplied with the `grad_outputs`. The code below shows what I'm trying to achieve, but it is horribly inefficient due to the loop. Here `instance_losses` is a vector that contains the loss per instance in the batch.
```python
params = self.get_params(model)

# dims: [batch]
instance_loss_grads = torch.autograd.grad(loss, instance_losses, create_graph=True)[0]

# EXPENSIVE OPERATION: compute the Jacobian w.r.t. the training instances.
# type: list (one entry per instance) of tuples (one entry per param group)
instance_param_grads = [
    torch.autograd.grad(instance_loss, params, retain_graph=True)
    for instance_loss in instance_losses
]

param_grads = []
for i in range(len(params)):  # for each param group
    # Create the Jacobian of this param group w.r.t. the training instances.
    # Note the `gs[i].detach()`: this is needed and is what makes the problem complex.
    group_jacobian = torch.stack([gs[i].detach() for gs in instance_param_grads], dim=0)  # dims: [batch, ...]
    # Multiply the differentiable `instance_loss_grads` by the constant Jacobian.
    param_grads.append(torch.einsum('i,ij...->j...', instance_loss_grads, group_jacobian))
```
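To make the intent of the final `einsum` concrete, here is a minimal shape-level sketch with stand-in tensors (the names and shapes are illustrative, not from the actual model): the batch dimension is contracted away, gradients flow only through the loss-grad vector, and the Jacobian acts as a constant.

```python
import torch

# Stand-in shapes to illustrate the contraction above.
batch, d1, d2 = 4, 3, 5
instance_loss_grads = torch.randn(batch, requires_grad=True)  # differentiable
group_jacobian = torch.randn(batch, d1, d2)                   # detached constant

# Contract the batch dimension; the result has the param group's shape.
param_grad = torch.einsum('i,ij...->j...', instance_loss_grads, group_jacobian)
assert param_grad.shape == (d1, d2)
# Autograd tracks only `instance_loss_grads`; the Jacobian contributes no gradient.
assert param_grad.requires_grad
```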
I see two potential options:
- Find a way to tell `torch.autograd.grad` that all operations during backprop must be performed after detaching.
- Find a way to compute the Jacobian efficiently. This has been discussed many times on the PyTorch forums with no general solution.
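For the second option, one possible direction (a sketch, not a drop-in replacement for the code above) is the `torch.func` per-sample-gradient recipe: `vmap(grad(...))` computes the whole Jacobian of per-instance losses w.r.t. the parameters in one vectorized call instead of a Python loop. The model, data, and loss below are hypothetical placeholders:

```python
import torch
from torch.func import functional_call, vmap, grad

# Hypothetical tiny model and batch, for illustration only.
model = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)
y = torch.randn(8, 1)

params = {k: v.detach() for k, v in model.named_parameters()}

def instance_loss(p, xi, yi):
    # Loss for a single instance; unsqueeze adds a batch dim of 1.
    out = functional_call(model, p, (xi.unsqueeze(0),))
    return torch.nn.functional.mse_loss(out, yi.unsqueeze(0))

# One vectorized call replaces the per-instance loop: map over the batch
# dims of x and y while sharing params. Each result tensor has a leading
# batch dimension, i.e. the stacked Jacobian blocks.
per_sample_grads = vmap(grad(instance_loss), in_dims=(None, 0, 0))(params, x, y)

# Detached Jacobian blocks, analogous to `group_jacobian` above.
jacobians = {k: g.detach() for k, g in per_sample_grads.items()}
# e.g. jacobians['weight'] has shape [batch, 1, 4]
```

Whether this composes with the outer `create_graph=True` gradient in your setting would need checking, but it removes the expensive loop.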
Any suggestions on how this can be optimized?