> When you say large gradients, do you have a specific norm in mind?
Currently, I am using the global l2-norm:
```python
def global_norm(grads):
    # Global L2 norm: square root of the sum of squared per-tensor L2 norms
    norm = 0
    for grad in grads:
        norm += grad.norm(2) ** 2
    return norm ** 0.5
```
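For reference, the same quantity can also be written in one line (just a sketch, assuming `grads` is a non-empty sequence of tensors):

```python
# Global L2 norm = L2 norm of the vector of per-tensor L2 norms
total_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2)
```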
> Does the function f factor into “application of NNs” and “compute loss”?
When you say “factor” here, I’m assuming you don’t mean multiplicative factorization. It may help to mention that I have settled on an intermediate solution that is not very performant. Essentially:
```python
f_optimizer = optim.Adam(f.parameters())
g_optimizer = optim.Adam(g.parameters())

g_optimizer.zero_grad()
# create_graph=True so that f_grad can itself be differentiated below
f_grad = torch.autograd.grad(compute_f_loss(), f.parameters(), create_graph=True)
global_norm(f_grad).backward()  # this assigns grads to both f.parameters() and g.parameters()
g_optimizer.step()              # this only steps g.parameters()

f_optimizer.zero_grad()         # this clears the gradients assigned to f.parameters()
compute_f_loss().backward()
f_optimizer.step()              # this only steps f.parameters()
```
There are two performance problems:
- It requires two backward passes, which is very slow. It would be great if I could reuse `f_grad` instead of recomputing it with `compute_f_loss().backward()` (see the sketch after this list).
- It requires two calls to `compute_f_loss()`. This might be remedied with `retain_graph=True`.
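A sketch of what reusing `f_grad` might look like (assuming it is valid to write the tensors returned by `torch.autograd.grad` directly into `.grad` before calling `step()`):

```python
f_optimizer = optim.Adam(f.parameters())
g_optimizer = optim.Adam(g.parameters())

f_optimizer.zero_grad()
g_optimizer.zero_grad()

# Single forward pass; create_graph=True so the norm penalty below can
# backpropagate through f_grad.
f_grad = torch.autograd.grad(compute_f_loss(), f.parameters(), create_graph=True)

global_norm(f_grad).backward()  # assigns grads to both f.parameters() and g.parameters()
g_optimizer.step()              # only steps g.parameters()

# Reuse f_grad for the f update instead of a second backward pass.
f_optimizer.zero_grad()         # discard the second-order grads on f.parameters()
for p, grad in zip(f.parameters(), f_grad):
    p.grad = grad.detach().clone()
f_optimizer.step()              # only steps f.parameters()
```

One caveat: this reuses gradients computed before `g_optimizer.step()`, whereas the second `compute_f_loss()` call in the snippet above would see the updated `g`.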