Higher-order gradients w.r.t. different functions

When you say large gradients, do you have a specific norm in mind?

Currently, I am using the global l2-norm:

def global_norm(grads):
    norm = 0
    for grad in grads:
        norm += grad.norm(2)**2
    return norm**.5
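
As a quick sanity check, this should agree with the L2 norm of all the gradients flattened into one vector. The tensors below are made up purely for illustration:

```python
import torch

def global_norm(grads):
    norm = 0
    for grad in grads:
        norm += grad.norm(2)**2
    return norm**.5

# Hypothetical gradient tensors with different shapes (illustration only).
grads = [torch.tensor([3.0, 0.0]), torch.tensor([[0.0], [4.0]])]

# Flatten everything into a single vector and take its L2 norm.
flat = torch.cat([grad.reshape(-1) for grad in grads])

print(global_norm(grads))  # sqrt(3^2 + 4^2) = 5
print(flat.norm(2))        # same value
```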

Does the function f factor into “application of NNs” and “compute loss”?

When you say “factor” here, I’m assuming you don’t mean multiplicative factorization. It may help to mention that I have settled on an intermediate solution that is not very performant. Essentially:

f_optimizer = optim.Adam(f.parameters())
g_optimizer = optim.Adam(g.parameters())

g_optimizer.zero_grad()
# create_graph=True keeps f_grad differentiable so that its norm can be backpropagated
f_grad = torch.autograd.grad(compute_f_loss(), f.parameters(), create_graph=True)
global_norm(f_grad).backward()  # this assigns grads to both f.parameters() and g.parameters()
g_optimizer.step()  # this only steps g.parameters()

f_optimizer.zero_grad()  # this clears the gradients assigned to f.parameters()
compute_f_loss().backward()
f_optimizer.step()  # this only steps f.parameters()

There are two performance problems:

  1. It requires two backward passes, which is very slow. It would be great if I could reuse f_grad instead of recomputing it with compute_f_loss().backward().
  2. It requires two calls to compute_f_loss(). This might be remedied with retain_graph=True.
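
Here is a rough sketch of how both problems might be avoided with a single call to compute_f_loss(). The modules f and g, the data x and y, and the body of compute_f_loss() are hypothetical stand-ins, since the real ones are not shown here. The idea is to differentiate the gradient norm with respect to g's parameters only, then assign a detached copy of f_grad to f's .grad fields. One caveat: this reuses the f_grad computed before g is stepped, whereas the two-pass version recomputes f's loss after g_optimizer.step(), so the two are not strictly equivalent.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Hypothetical stand-ins for f, g and the data (assumptions for illustration).
f = nn.Linear(4, 1)
g = nn.Linear(3, 4)
x, y = torch.randn(8, 3), torch.randn(8, 1)

def compute_f_loss():
    # Assumed form of the loss; f's loss depends on g through g's output.
    return ((f(torch.tanh(g(x))) - y) ** 2).mean()

def global_norm(grads):
    return sum(grad.pow(2).sum() for grad in grads).sqrt()

f_optimizer = optim.Adam(f.parameters())
g_optimizer = optim.Adam(g.parameters())
f_params = list(f.parameters())
g_params = list(g.parameters())

loss = compute_f_loss()  # single forward pass

# create_graph=True makes f_grad differentiable (and retains the graph).
f_grad = torch.autograd.grad(loss, f_params, create_graph=True)

# Differentiate the gradient norm with respect to g's parameters only,
# so no second-order gradients are accumulated on f's parameters.
g_grad = torch.autograd.grad(global_norm(f_grad), g_params)

# Hand the gradients to the optimizers manually.
for p, grad in zip(g_params, g_grad):
    p.grad = grad
for p, grad in zip(f_params, f_grad):
    p.grad = grad.detach()  # reuse f_grad instead of a second backward pass

g_optimizer.step()
f_optimizer.step()
```

If some of g's parameters do not influence f's gradient, the second torch.autograd.grad call would need allow_unused=True and the resulting None entries would have to be skipped.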