Efficient way to calculate gradient norms/gradient penalty

Currently, I have the following function implemented for getting the gradient norm for algorithms like SGA and Consensus Optimization:

import torch as tr
from torch import Tensor
from torch.autograd import grad


def hamiltonian_loss(d_x: Tensor, d_z: Tensor, gan: GAN):
    """
    Calculates the "Hamiltonian loss" as described in "The Mechanics of n-Player Differentiable Games" (Balduzzi et al.,
    2018), i.e. the loss function defined by 1/2 of the squared 2-norm of the overall gradient vector.

    @param d_x: Discriminator evaluated at real samples
    @param d_z: Discriminator evaluated at generated samples
    @param gan: GAN module
    @return: Hamiltonian loss
    """

    # Get gradient
    g_loss_ = gan.generator_loss(d_z)
    d_loss_ = gan.discriminator_loss(d_x, d_z)

    g_grad = grad(g_loss_, gan.G.parameters(), retain_graph=True, create_graph=True)
    d_grad = grad(d_loss_, gan.D.parameters(), retain_graph=True, create_graph=True)

    gradient_vec = tr.cat([tr.flatten(w) for w in g_grad] + [tr.flatten(w) for w in d_grad])

    norm = tr.pow(tr.linalg.norm(gradient_vec), 2) / 2
    return norm

I know this involves a second backpropagation pass and is going to come with a performance hit, but currently evaluation is very, very slow (over a second per iteration).
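To narrow down where the time actually goes, the individual stages can be timed separately. The following is only a rough sketch: `timed` is a hypothetical helper, and the small `model` and squared-output loss are placeholders standing in for the real GAN and its losses.

```python
import time

import torch
from torch import nn
from torch.autograd import grad


def timed(fn, *args, **kwargs):
    """Run fn and return (result, elapsed seconds), syncing CUDA if available."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure pending GPU work is not counted
    start = time.perf_counter()
    out = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for the work launched by fn
    return out, time.perf_counter() - start


# Placeholder network and loss (assumptions, not the real GAN).
model = nn.Sequential(nn.Linear(32, 64), nn.Tanh(), nn.Linear(64, 1))
loss = model(torch.randn(128, 32)).pow(2).mean()

# Stage 1: first-order gradients, keeping the graph for a second pass.
grads, t_grad = timed(grad, loss, list(model.parameters()), create_graph=True)

# Stage 2: the flatten + concatenate step from the function above.
flat, t_cat = timed(lambda: torch.cat([torch.flatten(g) for g in grads]))

# Stage 3: backpropagating through the squared gradient norm.
h = flat.pow(2).sum() / 2
_, t_bwd = timed(h.backward)

print(f"grad: {t_grad:.4f}s  cat: {t_cat:.4f}s  second backward: {t_bwd:.4f}s")
```

Timings on a toy model like this say nothing about the real network; the point is only that the same three measurements can be taken around the corresponding lines of `hamiltonian_loss`.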

I think the issue has to do with the list comprehension, as it is the only part that is not pure PyTorch code. I’ve looked at the documentation for the (beta) higher-level API, but it doesn’t seem like that’s what I’m looking for.
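For comparison, the same quantity can be computed without flattening and concatenating, by summing the squared entries of each gradient tensor. This is a minimal sketch under assumptions: the toy `G` and `D` networks and the binary cross-entropy losses below are placeholders for the real `gan.G`, `gan.D`, `gan.generator_loss` and `gan.discriminator_loss`. Whether it makes a measurable difference is untested.

```python
import torch
from torch import nn
from torch.autograd import grad

torch.manual_seed(0)

# Toy stand-ins for the generator and discriminator (hypothetical).
G = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))
D = nn.Sequential(nn.Linear(2, 8), nn.Tanh(), nn.Linear(8, 1))

z = torch.randn(16, 4)  # latent samples
x = torch.randn(16, 2)  # "real" samples
d_x, d_z = D(x), D(G(z))

# Placeholder GAN losses (assumed; the real ones live on the GAN module).
bce = nn.functional.binary_cross_entropy_with_logits
g_loss = bce(d_z, torch.ones_like(d_z))
d_loss = bce(d_x, torch.ones_like(d_x)) + bce(d_z, torch.zeros_like(d_z))

# create_graph=True also retains the graph, so both calls can share it.
g_grad = grad(g_loss, list(G.parameters()), create_graph=True)
d_grad = grad(d_loss, list(D.parameters()), create_graph=True)
grads = g_grad + d_grad  # grad() returns tuples, so + concatenates them

# Version from the post: flatten, concatenate, then take the norm.
flat = torch.cat([torch.flatten(g) for g in grads])
h_cat = torch.pow(torch.linalg.norm(flat), 2) / 2

# Alternative: sum of squared entries per tensor, no concatenation.
h_sum = sum(g.pow(2).sum() for g in grads) / 2

print(torch.allclose(h_cat, h_sum, rtol=1e-4))

# Either value can be backpropagated to get the second-order term.
h_sum.backward()
```

Both versions still loop over the parameter tensors in Python, so any difference would come from skipping the copy made by `torch.cat` and the square root inside `torch.linalg.norm`.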

Is there a way to do this more efficiently?