I currently use the following function to compute the gradient norm for algorithms such as Symplectic Gradient Adjustment (SGA) and Consensus Optimization:
import torch as tr
from torch import Tensor
from torch.autograd import grad

# GAN is my own module class (definition not shown), hence the string annotation.
def hamiltonian_loss(d_x: Tensor, d_z: Tensor, gan: "GAN"):
    """
    Calculates the "hamiltonian loss" as described in "The Mechanics of n-Player Differentiable Games" (Balduzzi et al.,
    2018), i.e. the loss function defined by 1/2 of the squared 2-norm of the overall gradient vector.
    @param d_x: Discriminator evaluated at real samples
    @param d_z: Discriminator evaluated at generated samples
    @param gan: GAN module
    @return: Hamiltonian loss
    """
    # Get the gradient of each player's loss w.r.t. its own parameters
    g_loss_ = gan.generator_loss(d_z)
    d_loss_ = gan.discriminator_loss(d_x, d_z)
    # create_graph=True keeps the gradients differentiable for the second backward pass
    g_grad = grad(g_loss_, gan.G.parameters(), retain_graph=True, create_graph=True)
    d_grad = grad(d_loss_, gan.D.parameters(), retain_graph=True, create_graph=True)
    # Flatten all per-parameter gradients into a single vector
    gradient_vec = tr.cat([tr.flatten(w) for w in g_grad] + [tr.flatten(w) for w in d_grad])
    norm = tr.pow(tr.linalg.norm(gradient_vec), 2) / 2
    return norm
I know this involves a second backpropagation pass and is going to come with a performance hit, but currently evaluation is very, very slow (over a second per iteration).
I suspect the list comprehensions are the bottleneck, since they are the only part that is not pure PyTorch code. I've looked at the documentation for the functional higher-level API (currently in beta), but it doesn't seem to be what I'm looking for.
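One way to check this hypothesis is to time each stage separately. The following is only a minimal sketch: `G`, `D`, and the softplus loss are hypothetical stand-ins for the GAN class, which is not shown here, and the timings are only meaningful on CPU as written (on a GPU, `tr.cuda.synchronize()` would be needed before each clock read).

```python
import time

import torch as tr
from torch.autograd import grad

tr.manual_seed(0)

# Hypothetical stand-ins for gan.G and gan.D (the real GAN class is not shown).
G = tr.nn.Sequential(tr.nn.Linear(8, 16), tr.nn.ReLU(), tr.nn.Linear(16, 4))
D = tr.nn.Sequential(tr.nn.Linear(4, 16), tr.nn.ReLU(), tr.nn.Linear(16, 1))

z = tr.randn(32, 8)
d_z = D(G(z))
# Placeholder generator loss, assumed purely for illustration.
g_loss = tr.nn.functional.softplus(-d_z).mean()

timings = {}

# Stage 1: first-order gradient, kept differentiable via create_graph=True
t0 = time.perf_counter()
g_grad = grad(g_loss, list(G.parameters()), create_graph=True)
timings["first-order grad"] = time.perf_counter() - t0

# Stage 2: the list comprehension and concatenation under suspicion
t0 = time.perf_counter()
gradient_vec = tr.cat([tr.flatten(w) for w in g_grad])
timings["flatten + cat"] = time.perf_counter() - t0

# Stage 3: half the squared 2-norm
t0 = time.perf_counter()
hamiltonian = tr.pow(tr.linalg.norm(gradient_vec), 2) / 2
timings["norm"] = time.perf_counter() - t0

# Stage 4: backpropagating through the gradient (the second pass)
t0 = time.perf_counter()
hamiltonian.backward()
timings["second backward"] = time.perf_counter() - t0

for stage, seconds in timings.items():
    print(f"{stage}: {seconds * 1e3:.3f} ms")
```

If the "flatten + cat" stage turns out to be a small fraction of the total, the list comprehension is probably not the main cost.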
Is there a way to do this more efficiently?
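One option that avoids building the flat gradient vector is to sum the squared entries of each per-parameter gradient directly, since the squared 2-norm of a concatenation equals the sum of the squared 2-norms of its pieces. The sketch below is an illustration under assumptions, not a measured fix: `half_squared_grad_norm` and its `losses_and_params` argument are hypothetical names, and the toy `G`, `D`, and softplus losses stand in for the GAN class. Whether it is noticeably faster would need to be measured, because the double backward pass may dominate either way.

```python
import torch as tr
from torch.autograd import grad


def half_squared_grad_norm(losses_and_params):
    """Return 0.5 * ||gradient||^2 without building one flat gradient vector.

    `losses_and_params` is a hypothetical argument: an iterable of
    (scalar loss, list of parameters) pairs, one pair per player.
    """
    total = 0.0
    for loss, params in losses_and_params:
        # create_graph=True keeps the result differentiable (and retains the graph)
        grads = grad(loss, params, create_graph=True)
        # Sum of squared entries over all parameter gradients of this player
        total = total + sum(g.pow(2).sum() for g in grads)
    return total / 2


# Toy comparison against the flatten-and-concatenate version
tr.manual_seed(0)
G = tr.nn.Linear(3, 2)  # stand-in for gan.G
D = tr.nn.Linear(2, 1)  # stand-in for gan.D
d_z = D(G(tr.randn(5, 3)))
g_loss = tr.nn.functional.softplus(-d_z).mean()  # placeholder loss
d_loss = tr.nn.functional.softplus(d_z).mean()   # placeholder loss
g_params, d_params = list(G.parameters()), list(D.parameters())

h_sum = half_squared_grad_norm([(g_loss, g_params), (d_loss, d_params)])

g_grad = grad(g_loss, g_params, create_graph=True)
d_grad = grad(d_loss, d_params, create_graph=True)
vec = tr.cat([tr.flatten(w) for w in g_grad] + [tr.flatten(w) for w in d_grad])
h_cat = tr.pow(tr.linalg.norm(vec), 2) / 2

print(tr.allclose(h_sum, h_cat))
```

Both versions return a scalar tensor that can be passed to `.backward()` in the same way.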