Not sure how to implement GAN optimizer

I’m trying to implement the ConOpt optimizer for GANs, which combines stochastic gradient descent for both the generator and discriminator with a gradient penalty regularizer. I know the standard practice is to give the generator and discriminator separate optimizers, but this algorithm adds the squared norm of the gradients from both models to each loss. So I need something like the closure argument of the Optimizer.step() method, but with two closures, one for each model. Does anyone have a suggested way of doing this with only one closure?

I don’t know if the only way to do it would be, for example, to make a new class GANOptimizer whose step() method requires two closures, one for each loss function. That’s pretty messy, but I’m not sure of another solution.

If you want to implement a min-max loss using a single optimizer, have you tried reverse grad? Not sure if that would help…
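For reference, "reverse grad" here presumably means a gradient reversal layer; that reading is an assumption. A minimal sketch with `torch.autograd.Function` follows (the class name `ReverseGrad` is made up for illustration). It only flips the sign of the gradient flowing backward, so a single optimizer can minimize for one model and maximize for the other. It does not add any gradient-norm regularizer on its own.

```python
import torch


class ReverseGrad(torch.autograd.Function):
    """Identity in the forward pass; negates the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output


x = torch.tensor([1.0, 2.0], requires_grad=True)
y = ReverseGrad.apply(x).sum()
y.backward()
# x.grad now holds the negated gradient of sum(), i.e. -1 for each element
```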

I’m not sure I understand your reply. Ideally, I’d like to have separate optimizers for the generator and discriminator, but the loss depends on both for ConOpt, even if the optimizer is only optimizing one set of parameters.
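One possible way to keep two separate optimizers is to compute the combined gradient outside of them, write it into each parameter's `.grad`, and then call `step()` on both. The sketch below is untested against a real GAN and rests on several assumptions: the tiny `nn.Linear` models, the softplus form of the GAN loss, the weight `gamma`, and the helper name `conopt_step` are all placeholders.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-ins for the real generator and discriminator (hypothetical sizes)
G = nn.Linear(2, 2)
D = nn.Linear(2, 1)
g_opt = torch.optim.SGD(G.parameters(), lr=0.01)
d_opt = torch.optim.SGD(D.parameters(), lr=0.01)
gamma = 0.1  # regularizer weight (assumed value)


def conopt_step(real, z):
    d_params, g_params = list(D.parameters()), list(G.parameters())

    d_real, d_fake = D(real), D(G(z))
    # Non-saturating GAN losses written with softplus
    d_loss = F.softplus(-d_real).mean() + F.softplus(d_fake).mean()
    g_loss = F.softplus(-d_fake).mean()

    # First-order gradients, kept in the graph so they can be differentiated again
    d_grads = torch.autograd.grad(d_loss, d_params, create_graph=True)
    g_grads = torch.autograd.grad(g_loss, g_params, create_graph=True)
    grads = d_grads + g_grads

    # Shared regularizer: half the squared norm of both players' gradients
    reg = 0.5 * sum(g.pow(2).sum() for g in grads)
    reg_grads = torch.autograd.grad(reg, d_params + g_params, allow_unused=True)

    # Write (loss gradient + gamma * regularizer gradient) into .grad
    for p, g, rg in zip(d_params + g_params, grads, reg_grads):
        p.grad = g.detach() if rg is None else g.detach() + gamma * rg

    # Each optimizer only updates its own model's parameters
    d_opt.step()
    g_opt.step()
    return d_loss.item(), g_loss.item()
```

Calling `conopt_step(real_batch, noise_batch)` once per iteration would replace the usual pair of `backward()` and `step()` calls. The second-order gradients make each step noticeably more expensive than a plain update.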

Sorry, I misinterpreted the question. As I understand it, you are trying to implement the Consensus Optimizer in PyTorch. The code would be something like this:

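import torch


class ConsensusOptimizer(object):
    def __init__(self, params, learning_rate, alpha=0.1, beta=0.9, eps=1e-8):
        # RMSprop needs the parameters to optimize at construction time
        self.optimizer = torch.optim.RMSprop(params, lr=learning_rate)
        self._eps = eps
        self._alpha = alpha
        self._beta = beta

    def conciliate(self, d_loss, g_loss, d_vars, g_vars):
        alpha = self._alpha

        # Compute gradients; create_graph=True lets us differentiate them again
        d_grads = torch.autograd.grad(d_loss, d_vars, create_graph=True)
        g_grads = torch.autograd.grad(g_loss, g_vars, create_graph=True)

        # Merge variable and gradient lists
        variables = list(d_vars) + list(g_vars)
        grads = list(d_grads) + list(g_grads)

        # Regularizer: half the squared norm of all gradients
        reg = 0.5 * sum(torch.sum(torch.square(g)) for g in grads)

        # Jacobian times gradient
        Jgrads = torch.autograd.grad(reg, variables, allow_unused=True)

        # The snippet as posted breaks off at this loop. Assumed completion:
        # set each parameter's gradient to grad + alpha * Jgrad, then step.
        for param, g, Jg in zip(variables, grads, Jgrads):
            param.grad = g.detach() if Jg is None else g.detach() + alpha * Jg
        self.optimizer.step()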

I assume you call this with both losses in place of loss.backward(), since it computes the gradients itself with torch.autograd.grad. This is untested code and you will probably need to make some changes.
See more info here

That’s more or less what I’ve tried already but it doesn’t work. My understanding is that the best way to implement this is as an optimizer, hence my original post.