Can I use loss.backward() for backpropagation with respect to the weights of another neural network?

Hi Jeet!

Almost, but with an important detail. When you backpropagate from
the discriminator back up into the generator, you need to flip the sign
of the gradient.

(Also, I see no benefit to storing the network weights in a separate
“large vector.” Just leave them in the networks themselves.)

I’ve never built a GAN, so I am fuzzy on the details, but the basic
idea is as follows:

You have a generator network (Gen) that produces “fake” images
that look real and you have a discriminator network (Disc) whose
job it is to tell the fake images apart from real ones.

So Disc is an ordinary classifier – “fake” vs. “real” – and you can
train it with something like BCEWithLogitsLoss.
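Here's a minimal sketch of Disc as an ordinary binary classifier. The architecture, the 1x28x28 image size, and the labeling convention (1 = “real”, 0 = “fake”) are all assumptions for illustration. Disc outputs one raw logit per image, and BCEWithLogitsLoss applies the sigmoid internally, so Disc itself should not end in a sigmoid.

```python
import torch
import torch.nn as nn

# Hypothetical discriminator: flattens a 1x28x28 image and outputs one logit.
disc = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.LeakyReLU(0.2),
    nn.Linear(128, 1),
)

loss_fn = nn.BCEWithLogitsLoss()  # sigmoid + binary cross-entropy in one step

images = torch.randn(16, 1, 28, 28)  # stand-in batch of images
labels = torch.ones(16, 1)           # assumed convention: 1 = "real", 0 = "fake"

logits = disc(images)                # shape (16, 1): raw scores, no sigmoid
loss = loss_fn(logits, labels)       # scalar classification loss
```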

But the idea of a GAN is to also train Gen to generate fake images
that fool Disc into classifying them as real. The scheme is to train
Disc so that the loss from Disc goes down but train Gen so that the
loss from Disc goes up.

You can do this as follows:

Feed a real image into Disc and calculate the classification loss.
Backpropagate it through Disc, updating Disc’s weights.
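In code, “backpropagate” and “update” are two separate calls: loss.backward() only computes the gradients, and the optimizer's step() is what actually changes Disc's weights. A sketch of this real-image step, assuming a toy Disc, random stand-in data, and plain SGD with an arbitrary learning rate:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy discriminator and optimizer (architecture and lr are assumptions).
disc = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1))
opt_disc = torch.optim.SGD(disc.parameters(), lr=0.01)
loss_fn = nn.BCEWithLogitsLoss()

real_images = torch.randn(16, 1, 28, 28)  # stand-in for a batch of real images
real_labels = torch.ones(16, 1)           # label 1 = "real"

before = disc[1].weight.detach().clone()  # snapshot to confirm the update

opt_disc.zero_grad()                      # clear gradients from earlier steps
loss_real = loss_fn(disc(real_images), real_labels)
loss_real.backward()                      # computes gradients for Disc's parameters
opt_disc.step()                           # actually updates Disc's weights

changed = not torch.equal(before, disc[1].weight)
```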

Now feed some random input into Gen. Gen acts sort of like a decoder
and “decodes” the random input into a fake image. The fake image output
by Gen depends on Gen’s weights and has requires_grad = True.
Feed this fake image into Disc, calculate the classification loss and
backpropagate it. This also updates Disc’s weights, continuing to train
it to distinguish fake from real.
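To see that Disc's loss on a fake image reaches back into Gen, here's a sketch with hypothetical toy networks (the 64-dimensional noise vector and the layer sizes are assumptions). Because the fake image is computed from Gen's parameters, a single backward() fills in .grad for both networks. As written, though, the gradients landing in Gen would push Gen toward helping Disc, which is exactly what the sign flip has to fix.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy generator: 64-dim noise -> 1x28x28 "image".
gen = nn.Sequential(nn.Linear(64, 28 * 28), nn.Tanh(), nn.Unflatten(1, (1, 28, 28)))
# Hypothetical toy discriminator: image -> one logit.
disc = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1))
loss_fn = nn.BCEWithLogitsLoss()

noise = torch.randn(16, 64)       # random input to Gen
fake_images = gen(noise)          # Gen "decodes" noise into fake images
print(fake_images.requires_grad)  # prints True: depends on Gen's weights

fake_labels = torch.zeros(16, 1)  # label 0 = "fake"
loss_fake = loss_fn(disc(fake_images), fake_labels)
loss_fake.backward()

# Both networks now hold gradients of Disc's classification loss.
gen_has_grad = all(p.grad is not None for p in gen.parameters())
disc_has_grad = all(p.grad is not None for p in disc.parameters())
```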

The key point:

When we further backpropagate Disc’s classification loss for the
fake image through Gen – which we can do because the input to
Disc came from Gen, depends on Gen’s weights, and carries
requires_grad = True – we flip the sign of the gradient. This
is because we want to penalize Gen if Disc did well, and reward
Gen if Disc did poorly when classifying the fake image. That is,
we train Gen and Disc at cross-purposes with one another.
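Written as update rules, the effect of the sign flip is easy to see. Here L_D is Disc's classification loss on the fake image, θ_D and θ_G are Disc's and Gen's weights, and η is a learning rate; these symbols, and the assumption of plain gradient descent, are introduced just for illustration. Both networks take an ordinary descent step, but because Gen receives the negated gradient, its descent step is really gradient ascent on Disc's loss:

```latex
\theta_D \leftarrow \theta_D - \eta \,\frac{\partial L_D}{\partial \theta_D}
\qquad \text{(Disc: push } L_D \text{ down)}

\theta_G \leftarrow \theta_G - \eta \left(-\frac{\partial L_D}{\partial \theta_G}\right)
        = \theta_G + \eta \,\frac{\partial L_D}{\partial \theta_G}
\qquad \text{(Gen: push } L_D \text{ up)}
```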

One convenient way to effect this gradient sign-flip is to interpose
a “sign-flip” Function between Gen and Disc. During the forward
pass, the sign-flip Function simply passes its input through unchanged.
(That is, we pass the fake image generated by Gen unchanged into
Disc.) But on the backward pass the sign-flipper takes the gradient
it’s given and flips its sign before sending it on to Gen for further
backpropagation.

I’m not aware that PyTorch offers a pre-packaged sign-flip Function,
but it’s easy enough to write one.
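Here's one way to write it as a custom torch.autograd.Function. The names SignFlip and sign_flip are hypothetical, as are the toy linear Gen and Disc (images flattened to 784-vectors for brevity) and the SGD optimizers. This pattern is often called a “gradient reversal layer.” The forward pass is the identity; the backward pass negates whatever gradient arrives from Disc before handing it on to Gen.

```python
import torch
import torch.nn as nn


class SignFlip(torch.autograd.Function):
    """Identity on the forward pass; negates the gradient on the backward pass."""

    @staticmethod
    def forward(ctx, x):
        # Pass the input through unchanged. view_as returns a new tensor
        # object, so autograd records this Function in the graph.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign of the gradient before it continues on into Gen.
        return -grad_output


def sign_flip(x):
    return SignFlip.apply(x)


# Usage sketch: interpose sign_flip between Gen and Disc.
torch.manual_seed(0)
gen = nn.Linear(64, 28 * 28)  # hypothetical toy generator
disc = nn.Linear(28 * 28, 1)  # hypothetical toy discriminator
opt_gen = torch.optim.SGD(gen.parameters(), lr=0.01)
opt_disc = torch.optim.SGD(disc.parameters(), lr=0.01)
loss_fn = nn.BCEWithLogitsLoss()

noise = torch.randn(16, 64)
fake = gen(noise)                      # fake images, requires_grad = True
logits = disc(sign_flip(fake))         # forward: fake passes through unchanged
loss_fake = loss_fn(logits, torch.zeros(16, 1))  # label 0 = "fake"

opt_gen.zero_grad()
opt_disc.zero_grad()
loss_fake.backward()  # Disc gets normal gradients; Gen gets sign-flipped ones
opt_disc.step()       # Disc descends on its loss
opt_gen.step()        # Gen, via the flipped gradient, ascends on that same loss
```

Because the flip sits between Gen and Disc, Disc's own gradients are untouched, so one backward() call trains both networks at cross-purposes.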

Some additional discussion about flipping the gradient’s sign can be
found in this thread:

Best.

K. Frank
