How to pack the two modules of a GAN (generator and discriminator) into a single module?

import torch.nn as nn

class Discriminator(nn.Module):
    # Scores a flattened image: outputs the probability that it is real.
    def __init__(self, in_features):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.disc(x)


class Generator(nn.Module):
    # Maps a latent vector z to a flattened fake image with values in [-1, 1].
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.01),
            nn.Linear(256, img_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.gen(x)

How should I design the forward() method if the generator and the discriminator are combined into a single module?

Hi Jeet!

First, as I am sure you are aware, there is no problem with having one
Module contain other Modules.
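As a quick illustration (a minimal sketch; the class name GAN, the compact layer stacks, and the sizes z_dim=64 / img_dim=784 are assumptions for the example), assigning two Modules as attributes of a parent Module registers them as children, so the parent's parameters() covers the weights of both networks:

```python
import torch
import torch.nn as nn

class GAN(nn.Module):
    """Hypothetical container holding both sub-networks as child Modules."""

    def __init__(self, z_dim, img_dim):
        super().__init__()
        # Compact stand-ins for the Generator and Discriminator classes above.
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256), nn.LeakyReLU(0.01),
            nn.Linear(256, img_dim), nn.Tanh(),
        )
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128), nn.LeakyReLU(0.01),
            nn.Linear(128, 1), nn.Sigmoid(),
        )

    def forward(self, z):
        # Generate fake samples, then score them with the discriminator.
        return self.disc(self.gen(z))

gan = GAN(z_dim=64, img_dim=784)

# Child Modules are registered in assignment order, and the parent's
# parameters() yields the parameters of every child.
n_gen = sum(p.numel() for p in gan.gen.parameters())
n_disc = sum(p.numel() for p in gan.disc.parameters())
n_all = sum(p.numel() for p in gan.parameters())
print([name for name, _ in gan.named_children()])  # ['gen', 'disc']
print(n_gen + n_disc == n_all)                     # True

scores = gan(torch.randn(32, 64))
print(scores.shape)  # torch.Size([32, 1])
```

Note that the discriminator's in_features has to equal the generator's img_dim for the two to be chained.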

Now what your forward() method should look like depends on the
details of how you want to train your GAN. (There are a number of
approaches with varying degrees of nuance.)

In one simple scheme, you perform a forward pass in which you first
generate a fake image with your Generator Module and then pass that
fake image (possibly together with a real image) through your
Discriminator Module and on to a final loss function that rewards your
Discriminator for successfully distinguishing fake from real.
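A rough sketch of that simple scheme, assuming nn.BCELoss with label 1 for real and 0 for fake, compact stand-ins for the two networks, and a random tensor in place of a real image batch. It also shows the issue raised in the next paragraph: because the fake batch is not detached, the Discriminator's loss sends gradients into the Generator too, but in the direction that helps the Discriminator:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
z_dim, img_dim, batch = 64, 784, 16

# Compact stand-ins for the Generator and Discriminator from the question.
gen = nn.Sequential(nn.Linear(z_dim, 256), nn.LeakyReLU(0.01),
                    nn.Linear(256, img_dim), nn.Tanh())
disc = nn.Sequential(nn.Linear(img_dim, 128), nn.LeakyReLU(0.01),
                     nn.Linear(128, 1), nn.Sigmoid())
bce = nn.BCELoss()

real = torch.rand(batch, img_dim) * 2 - 1  # placeholder for a real batch in [-1, 1]
fake = gen(torch.randn(batch, z_dim))      # fake batch, still attached to gen's graph

# Discriminator loss: label real samples 1 and fake samples 0.
loss_disc = (bce(disc(real), torch.ones(batch, 1))
             + bce(disc(fake), torch.zeros(batch, 1)))
loss_disc.backward()

# The generator received gradients as well, because fake was not detached.
print(all(p.grad is not None for p in gen.parameters()))  # True
```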

The issue is that you also want to reward your Generator for successfully
fooling your Discriminator into mistaking fake for real. That is, you
want to train – in the same backward pass – your Discriminator and
Generator “in opposite directions.” One way to do this is to flip the sign
of the gradients that are backpropagated through the Generator.

You can do this by introducing (after, of course, having written it) a layer,
packaged as a custom torch.autograd.Function, that passes the tensor
through unchanged on the forward pass but then flips the sign of the
gradient on the backward pass.
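Such a layer might look like the following sketch (it reuses the name CustomGradientSignFlippingFunction from the example below; the implementation itself is one common way to write a gradient-reversal Function, not necessarily the only one). The usage lines check that the forward pass is the identity and that the gradient arriving at x is negated:

```python
import torch

class CustomGradientSignFlippingFunction(torch.autograd.Function):
    """Identity on the forward pass; negates the gradient on the backward pass."""

    @staticmethod
    def forward(ctx, x):
        # view_as returns a new tensor object sharing x's data; this is a
        # common idiom in gradient-reversal layers.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the incoming gradient before it reaches upstream layers.
        return -grad_output

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = CustomGradientSignFlippingFunction.apply(x)  # call apply() on the class itself
(y * y).sum().backward()

# d(sum(x**2))/dx would be 2*x; the flip layer negates it.
print(x.grad)  # tensor([-2., -4., -6.])
```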

You would then do something like this:

class MyGANModule (nn.Module) :
    def __init__ (self, in_features, z_dim, img_dim) :
        super().__init__()
        self.my_gen = Generator (z_dim, img_dim)
        # autograd Functions are not instantiated; keep the class and call apply() on it
        self.gradient_flipper = CustomGradientSignFlippingFunction
        self.my_disc = Discriminator (in_features)   # in_features must equal img_dim
    
    def forward (self, x) :
        x = self.my_gen (x)
        x = self.gradient_flipper.apply (x)
        x = self.my_disc (x)
        return x
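To make the "same backward pass" idea concrete, here is a minimal sketch of one training step with a module like MyGANModule, assuming a single SGD optimizer over all parameters, nn.BCELoss, compact stand-ins for the two networks, and a random tensor in place of real data. The real images go straight into my_disc, while the fakes go through the flip layer, so one loss.backward() and one opt.step() push the Discriminator toward a lower loss and the Generator toward a higher one:

```python
import torch
import torch.nn as nn

class CustomGradientSignFlippingFunction(torch.autograd.Function):
    # Identity on the forward pass, negated gradient on the backward pass.
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output

class MyGANModule(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        # Compact stand-ins for the Generator and Discriminator from the question.
        self.my_gen = nn.Sequential(nn.Linear(z_dim, 256), nn.LeakyReLU(0.01),
                                    nn.Linear(256, img_dim), nn.Tanh())
        self.my_disc = nn.Sequential(nn.Linear(img_dim, 128), nn.LeakyReLU(0.01),
                                     nn.Linear(128, 1), nn.Sigmoid())

    def forward(self, z):
        x = self.my_gen(z)
        x = CustomGradientSignFlippingFunction.apply(x)
        return self.my_disc(x)

torch.manual_seed(0)
z_dim, img_dim, batch = 64, 784, 16
gan = MyGANModule(z_dim, img_dim)
opt = torch.optim.SGD(gan.parameters(), lr=0.01)  # one optimizer for both networks
bce = nn.BCELoss()

real = torch.rand(batch, img_dim) * 2 - 1  # placeholder for a real batch in [-1, 1]
z = torch.randn(batch, z_dim)

opt.zero_grad()
# Real data bypasses the generator; fakes pass through the flip layer.
loss = (bce(gan.my_disc(real), torch.ones(batch, 1))
        + bce(gan(z), torch.zeros(batch, 1)))
loss.backward()  # D's gradients are normal; G's gradients arrive negated
opt.step()
```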

(You could, alternatively, package Generator,
CustomGradientSignFlippingFunction, and Discriminator together
as a Sequential, if you prefer the Sequential style. Because Sequential
only accepts Modules, the Function would first need a thin Module
wrapper.)
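A sketch of the Sequential variant, assuming a hypothetical wrapper Module named GradientFlip (nn.Sequential raises a TypeError for anything that is not a Module, so the Function cannot be placed in it directly); the compact layer stacks again stand in for the question's classes:

```python
import torch
import torch.nn as nn

class CustomGradientSignFlippingFunction(torch.autograd.Function):
    # Identity on the forward pass, negated gradient on the backward pass.
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output

class GradientFlip(nn.Module):
    """Hypothetical Module wrapper so the Function can sit inside nn.Sequential."""

    def forward(self, x):
        return CustomGradientSignFlippingFunction.apply(x)

z_dim, img_dim = 64, 784
gan = nn.Sequential(
    # Generator (compact stand-in)
    nn.Sequential(nn.Linear(z_dim, 256), nn.LeakyReLU(0.01),
                  nn.Linear(256, img_dim), nn.Tanh()),
    GradientFlip(),
    # Discriminator (compact stand-in)
    nn.Sequential(nn.Linear(img_dim, 128), nn.LeakyReLU(0.01),
                  nn.Linear(128, 1), nn.Sigmoid()),
)

scores = gan(torch.randn(8, z_dim))
print(scores.shape)  # torch.Size([8, 1])
```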

So my_gen and my_disc are themselves both Modules, but they are
contained in (said another way, are attributes of) your MyGANModule
Module.

MyGANModule has its own forward() method, but when that forward()
method calls, for example, my_gen (x), PyTorch's infrastructure then
calls Generator's own forward() method on x.

Best.

K. Frank


Hi Frank,

Thanks again for the detailed explanation. You are right that flipping the gradient is needed if the Generator is trained in the same backward pass. In my case, however, I am using a different formula to train the GAN: the weights of both the Generator and the Discriminator are updated with a gradient-descent formula that takes a Jacobian matrix as input. That Jacobian contains the first-order partial derivatives of the generator's and the discriminator's loss functions with respect to all of the weights (of both the Generator and the Discriminator).
So I was thinking that if I put both the generator and the discriminator modules in a single module, I could get the partial derivatives of each loss function with respect to all the weights; for example, LossGenerator.backward() and LossDiscriminator.backward(), but with respect to all the weights in the model.
The problem, however, is that in a GAN the losses are calculated separately for the generator and the discriminator, because the real data (the real images, of size img_dim above) is fed only to the discriminator, not to the generator. The discriminator's output on the generated data is then used to calculate the generator's loss. Since the losses are calculated separately, I have no idea how to calculate the first-order partial derivatives of each loss function with respect to all the weights.
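One possible way to build such a matrix is sketched below, under stated assumptions: the standard BCE losses (the generator loss labels its fakes as real), compact stand-ins for the two networks, a random placeholder for the real batch, and a row-per-loss layout, which may or may not match what the formula expects. The losses are computed separately, but both come from the same autograd graph, so torch.autograd.grad can differentiate each one with respect to the full parameter list; parameters a loss does not depend on are returned as None under allow_unused=True and are filled in with zeros here:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
z_dim, img_dim, batch = 64, 784, 16

# Compact stand-ins for the Generator and Discriminator from the question.
gen = nn.Sequential(nn.Linear(z_dim, 256), nn.LeakyReLU(0.01),
                    nn.Linear(256, img_dim), nn.Tanh())
disc = nn.Sequential(nn.Linear(img_dim, 128), nn.LeakyReLU(0.01),
                     nn.Linear(128, 1), nn.Sigmoid())
bce = nn.BCELoss()
params = list(gen.parameters()) + list(disc.parameters())  # all weights

real = torch.rand(batch, img_dim) * 2 - 1  # placeholder for a real batch
fake = gen(torch.randn(batch, z_dim))      # not detached: keeps the path to gen

# Assumed standard losses: D separates real from fake, G wants fakes scored as real.
loss_disc = (bce(disc(real), torch.ones(batch, 1))
             + bce(disc(fake), torch.zeros(batch, 1)))
loss_gen = bce(disc(fake), torch.ones(batch, 1))

def flat_grad(loss, params):
    """Gradient of a scalar loss w.r.t. every parameter, flattened into one vector."""
    # retain_graph=True because both losses share the generator's part of the graph.
    grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    # Parameters the loss does not depend on come back as None; treat them as zero.
    return torch.cat([(torch.zeros_like(p) if g is None else g).reshape(-1)
                      for p, g in zip(params, grads)])

# One row per loss, one column per weight (generator weights first).
jacobian = torch.stack([flat_grad(loss_disc, params), flat_grad(loss_gen, params)])
print(jacobian.shape)  # torch.Size([2, 318737])
```

Because fake is not detached, the discriminator's row has nonzero entries in the generator's columns, coming from the fake term of its loss.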
In the MyGANModule you suggested above, should I have three forward methods: one for the generator, one for the discriminator, and one combined (as you have written above)? I am not sure that would help. If I use MyGANModule to train the discriminator with the real images and compute the loss, can I then find the gradients (first-order partial derivatives) with respect to the weights of the generator as well?
Could you please provide an example code snippet?
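A Module is not limited to forward(): it can also have ordinary methods, so one sketch (with hypothetical helper names generate and discriminate, compact stand-in networks, and a random placeholder for real data) keeps a single forward() plus two plain methods. It also probes the last question: a discriminator loss computed on real images alone has no path to the generator, so torch.autograd.grad with allow_unused=True returns None for every generator parameter, while a loss on generated images does reach the generator:

```python
import torch
import torch.nn as nn

class MyGANModule(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        # Compact stand-ins for the Generator and Discriminator from the question.
        self.my_gen = nn.Sequential(nn.Linear(z_dim, 256), nn.LeakyReLU(0.01),
                                    nn.Linear(256, img_dim), nn.Tanh())
        self.my_disc = nn.Sequential(nn.Linear(img_dim, 128), nn.LeakyReLU(0.01),
                                     nn.Linear(128, 1), nn.Sigmoid())

    # Hypothetical helper methods: plain methods, not extra forward()s.
    def generate(self, z):
        return self.my_gen(z)

    def discriminate(self, x):
        return self.my_disc(x)

    def forward(self, z):
        # Combined path: generator output scored by the discriminator.
        return self.discriminate(self.generate(z))

torch.manual_seed(0)
gan = MyGANModule(z_dim=64, img_dim=784)
bce = nn.BCELoss()
real = torch.rand(16, 784) * 2 - 1  # placeholder for a real batch

# Discriminator loss on real images only: the generator is not in this graph.
loss_real = bce(gan.discriminate(real), torch.ones(16, 1))
g_real = torch.autograd.grad(loss_real, list(gan.my_gen.parameters()),
                             allow_unused=True)
print(all(g is None for g in g_real))  # True

# Loss on generated images: gradients do reach the generator's weights.
loss_fake = bce(gan(torch.randn(16, 64)), torch.zeros(16, 1))
g_fake = torch.autograd.grad(loss_fake, list(gan.my_gen.parameters()))
print(all(g is not None for g in g_fake))  # True
```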