Efficient detached vector Jacobian product


(Sid MS) #1

I am trying to compute input gradients of a network such that, in effect, the Jacobian of the network is detached before it's multiplied with the grad_outputs. The code below shows what I'm trying to achieve, but it is horribly inefficient due to the loop. Here, `instance_losses` is a vector containing the loss for each instance in the batch.

params = self.get_params(model)
instance_loss_grads = torch.autograd.grad(loss, [instance_losses], create_graph=True)[0]  # dims: [batch]

# EXPENSIVE OPERATION: Compute Jacobian wrt training instances.
instance_param_grads = [torch.autograd.grad(instance_loss, params, retain_graph=True) for instance_loss in instance_losses]  # one tuple of per-param grads per instance

param_grads = []
for i in range(len(params)):  # for each param group
    # Create Jacobian of param group wrt training instances
    # Note the `gs[i].detach()`. This is needed and is what makes the problem complex.
    group_jacobian = torch.stack([gs[i].detach() for gs in instance_param_grads], dim=0)  # dims: [batch, ...]

    # Multiply differentiable `instance_loss_grads` and constant Jacobian. 
    param_grads.append(torch.einsum('i,ij...->j...', instance_loss_grads, group_jacobian))
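As a self-contained check of what the loop computes (all names and shapes below are made up for illustration): the einsum contracts the batch dimension, and backprop through the product reaches only the differentiable vector, since the stacked Jacobian is detached.

```python
import torch

batch = 5
# Illustrative stand-ins: a constant (detached) Jacobian block for one
# parameter, and the differentiable per-instance loss-gradient vector.
group_jacobian = torch.randn(batch, 3, 2)                    # [batch, *param_shape]
instance_loss_grads = torch.randn(batch, requires_grad=True)

# Contract over the batch dimension; only the vector carries gradient.
param_grad = torch.einsum('i,ij...->j...', instance_loss_grads, group_jacobian)

# Backprop through the product reaches the vector but treats the Jacobian
# as a constant, which is the semantics the detach() in the loop enforces.
param_grad.sum().backward()
```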

I see two potential options:

  • Find a way to tell torch.autograd.grad that all operations during backprop must be performed after detaching.
  • Find a way to compute the Jacobian efficiently. This has been discussed many times on the PyTorch forums with no general solution.
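For the second option, the per-instance gradient loop can in principle be batched with `torch.func` (PyTorch ≥ 2.0). A minimal sketch with a toy model; `model`, `x`, `y`, and `instance_loss` are placeholder names, not from the code above:

```python
import torch
from torch.func import functional_call, grad, vmap

# Toy stand-ins for the real model and batch.
model = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)
y = torch.randn(8, 1)

# Detached copies of the parameters, keyed by name for functional_call.
param_dict = {name: p.detach() for name, p in model.named_parameters()}

def instance_loss(param_dict, xi, yi):
    # Loss for a single instance, evaluated with the given parameters.
    out = functional_call(model, param_dict, (xi.unsqueeze(0),))
    return ((out - yi) ** 2).mean()

# One batched call replaces the per-instance Python loop: grad() differentiates
# w.r.t. the first argument, vmap maps over the batch dimension of x and y.
per_sample_grads = vmap(grad(instance_loss), in_dims=(None, 0, 0))(param_dict, x, y)

# Each entry now has dims [batch, *param_shape]: the stacked, constant
# Jacobian rows that the loop built one instance at a time.
```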

Any suggestions on how this can be optimized?