Currently, I have the following function implemented for getting the gradient norm for algorithms like SGA and Consensus Optimization:
```python
import torch as tr
from torch import Tensor
from torch.autograd import grad


def hamiltonian_loss(d_x: Tensor, d_z: Tensor, gan: GAN) -> Tensor:
    """
    Calculates the "Hamiltonian loss" as described in "The Mechanics of
    n-Player Differentiable Games" (Balduzzi et al. 2018), i.e. the loss
    defined as 1/2 of the squared 2-norm of the overall gradient vector.

    @param d_x: Discriminator evaluated at real samples
    @param d_z: Discriminator evaluated at generated samples
    @param gan: GAN module
    @return: Hamiltonian loss
    """
    # Gradients of each player's loss w.r.t. its own parameters;
    # create_graph=True so the norm itself can be backpropagated through
    g_loss_ = gan.generator_loss(d_z)
    d_loss_ = gan.discriminator_loss(d_x, d_z)
    g_grad = grad(g_loss_, gan.G.parameters(), retain_graph=True, create_graph=True)
    d_grad = grad(d_loss_, gan.D.parameters(), retain_graph=True, create_graph=True)

    # Flatten everything into a single gradient vector and take half its squared norm
    gradient_vec = tr.cat([tr.flatten(w) for w in g_grad]
                          + [tr.flatten(w) for w in d_grad])
    norm = tr.pow(tr.linalg.norm(gradient_vec), 2) / 2
    return norm
```
I know this involves a second backpropagation pass and is going to come with a performance hit, but currently evaluation is very, very slow (over a second per iteration).
I suspect the issue is the list comprehension, since it is the only part that is not pure PyTorch code. I’ve looked at the documentation for the higher-level API beta, but it doesn’t seem to be what I’m looking for.
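For reference, the concatenation step could in principle be avoided by summing per-tensor squared norms instead of building one flat vector. This is just a sketch with dummy tensors standing in for the tuples returned by `grad`, to show that the two computations agree:

```python
import torch as tr

# Dummy stand-ins for the gradient tuples returned by torch.autograd.grad
grads = (tr.randn(3, 4), tr.randn(5), tr.randn(2, 2))

# Current approach: flatten, concatenate, then take half the squared 2-norm
flat = tr.cat([tr.flatten(w) for w in grads])
norm_cat = tr.pow(tr.linalg.norm(flat), 2) / 2

# Alternative: sum each tensor's squared norm, no intermediate concatenation
norm_sum = sum(tr.sum(w * w) for w in grads) / 2

assert tr.allclose(norm_cat, norm_sum)
```

I don’t know whether this would actually be faster here, since the flattening may not be the dominant cost compared to the double backward pass.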
Is there a way to do this more efficiently?