Hi Jeet!
First, as I am sure you are aware, there is no problem with having one
`Module` contain other `Module`s.
Now what your `forward()` method should look like depends on the
details of how you want to train your GAN. (There are a number of
approaches with varying degrees of nuance.)
In one simple scheme, you perform a forward pass in which you both
generate a fake image in your `Generator` `Module` and then pass
that fake image (possibly together with a real image) through your
`Discriminator` `Module` on to the final loss function that rewards your
`Discriminator` for successfully distinguishing fake from real.
The issue is that you also want to reward your `Generator` for successfully
fooling your `Discriminator` into mistaking fake for real. That is, you
want to train – in the same backward pass – your `Discriminator` and
`Generator` "in opposite directions." One way to do this is to flip the sign
of the gradients that are backpropagated back through the `Generator`.
You can do this by introducing (after, of course, having written it) a layer,
packaged as a custom `torch.autograd.Function`, that passes the tensor
through unchanged on the forward pass, but that then flips the sign of the
gradient on the backward pass.
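A minimal sketch of such a sign-flipping `Function` might look like this
(the name `FlipGradient` is mine, just for illustration):

```python
import torch

class FlipGradient(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # pass the tensor through unchanged
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # flip the sign of the incoming gradient
        return -grad_output
```

Note that you call `apply()` on the `Function` class itself (e.g.
`FlipGradient.apply(x)`); you don't instantiate the `Function`.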
You would then do something like this:
```python
class MyGANModule(nn.Module):
    def __init__(self, in_features, z_dim, img_dim):
        super().__init__()
        self.my_gen = Generator(z_dim, img_dim)
        self.my_disc = Discriminator(in_features)

    def forward(self, x):
        x = self.my_gen(x)
        # call apply() on the Function class itself, rather than
        # instantiating the Function (the latter is deprecated)
        x = CustomGradientSignFlippingFunction.apply(x)
        x = self.my_disc(x)
        return x
```
(You could, alternatively, package `Generator`,
`CustomGradientSignFlippingFunction`, and `Discriminator` together
as a `Sequential`, if you prefer the `Sequential` style.)
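If you go the `Sequential` route, note that a `torch.autograd.Function`
is not itself a `Module`, so you would wrap it in a thin `Module` first.
A minimal sketch (the names are mine, and the two `Linear`s just stand
in for your `Generator` and `Discriminator`):

```python
import torch
from torch import nn

class FlipGradient(torch.autograd.Function):
    # identity on the forward pass, flips the gradient's sign on backward
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output

class FlipGradientLayer(nn.Module):
    # thin Module wrapper so the Function can live inside a Sequential
    def forward(self, x):
        return FlipGradient.apply(x)

# hypothetical stand-ins for your Generator and Discriminator
gan = nn.Sequential(
    nn.Linear(8, 16),     # "Generator"
    FlipGradientLayer(),
    nn.Linear(16, 1),     # "Discriminator"
)
```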
So `my_gen` and `my_disc` are themselves both `Module`s, but they are
contained in (said another way, are properties of) your `MyGANModule`
`Module`.
`MyGANModule` has its own `forward()` method, but when that `forward()`
method calls, for example, `my_gen(x)`, pytorch's infrastructure then
calls `Generator`'s own `forward()` method on `x`.
Best.
K. Frank