Hi Jeet!
Almost, but with an important detail. When you backpropagate from
the discriminator back up into the generator, you need to flip the sign
of the gradient.
(Also, I see no benefit to storing the network weights in a separate
“large vector.” Just leave them in the networks themselves.)
I’ve never built a GAN, so I am fuzzy on the details, but the basic
idea is as follows:
You have a generator network (`Gen`) that produces “fake” images
that look real and you have a discriminator network (`Disc`) whose
job it is to tell the fake images apart from real ones.
So `Disc` is an ordinary classifier – “fake” vs. “real” – and you can
train it with something like `BCEWithLogitsLoss`.
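For illustration, here is a minimal sketch of such a real-vs.-fake
classification loss. (The toy discriminator, image size, and batch size
are made up purely for illustration.)

```python
import torch
import torch.nn as nn

# hypothetical toy discriminator: any network that maps an image to one logit
disc = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128), nn.ReLU(),
    nn.Linear(128, 1),
)
criterion = nn.BCEWithLogitsLoss()

real_images = torch.randn(16, 1, 28, 28)   # stand-in batch of "real" images
real_labels = torch.ones(16, 1)            # label 1 = "real", 0 = "fake"

loss = criterion(disc(real_images), real_labels)   # ordinary binary classification
loss.backward()
```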
But the idea of a GAN is to also train `Gen` to generate fake images
that fool `Disc` into classifying them as real. The scheme is to train
`Disc` so that the loss from `Disc` goes down but train `Gen` so that
the loss from `Disc` goes up.
You can do this as follows:
Feed a real image into `Disc` and calculate the classification loss.
Backpropagate it through `Disc`, updating `Disc`’s weights.
Now feed some random input into `Gen`. `Gen` acts sort of like a decoder
and “decodes” the random input into a fake image. The fake image output
by `Gen` depends on `Gen`’s weights and has `requires_grad = True`.
Feed this fake image into `Disc`, calculate the classification loss and
backpropagate it. This also updates `Disc`’s weights, continuing to train
it to distinguish fake from real.
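Putting these steps together, a minimal sketch might look like the
following. (The toy networks and sizes are made up; the sign-flip that
has to be interposed between `Gen` and `Disc` is the key point discussed
next.)

```python
import torch
import torch.nn as nn

# hypothetical toy generator and discriminator, just to illustrate the steps
gen  = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 28 * 28))
disc = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 1))
criterion = nn.BCEWithLogitsLoss()

real  = torch.randn(16, 28 * 28)   # stand-in batch of real images
noise = torch.randn(16, 64)        # random input that Gen "decodes" into fakes

# real image -> Disc, classification loss, backpropagate through Disc
loss_real = criterion(disc(real), torch.ones(16, 1))
loss_real.backward()

# random input -> Gen -> fake image; the fake depends on Gen's weights
# and carries requires_grad = True
fake = gen(noise)

# fake image -> Disc, classification loss, backpropagate; this accumulates
# gradients in Disc (continuing to train it) and also flows back into Gen,
# where the gradient's sign needs to be flipped (see the key point below)
loss_fake = criterion(disc(fake), torch.zeros(16, 1))
loss_fake.backward()
```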
The key point:
When we further backpropagate `Disc`’s classification loss for the
fake image through `Gen` – which we can do because the input to
`Disc` came from `Gen`, depends on `Gen`’s weights, and carries
`requires_grad = True` – we flip the sign of the gradient. This
is because we want to penalize `Gen` if `Disc` did well, and reward
`Gen` if `Disc` did poorly when classifying the fake image. That is,
we train `Gen` and `Disc` at cross-purposes with one another.
One convenient way to effect this gradient sign-flip is to interpose
a “sign-flip” `Function` between `Gen` and `Disc`. During the forward
pass, the sign-flip `Function` simply passes its input through unchanged.
(That is, we pass the fake image generated by `Gen` unchanged into
`Disc`.) But on the backward pass the sign-flipper takes the gradient
it’s given and flips its sign before sending it on to `Gen` for further
backpropagation.
I’m not aware that pytorch offers a pre-packaged sign-flip `Function`,
but it’s easy enough to write one.
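Something along these lines should work (a minimal sketch; the class
name is just my own choice):

```python
import torch

class FlipSign(torch.autograd.Function):
    """Pass the input through unchanged on the forward pass, but flip
    the sign of the gradient on the backward pass."""
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)    # forward pass: identity
    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output    # backward pass: flip the gradient's sign

# interpose the sign-flipper between Gen and Disc, e.g.:
#     fake = gen(noise)
#     loss_fake = criterion(disc(FlipSign.apply(fake)), torch.zeros(16, 1))
# Disc then sees the ordinary gradient, while Gen sees the flipped one.
```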
Some additional discussion about flipping the gradient’s sign can be
found in this thread:
Best.
K. Frank