GAN-training - Updating the generator

I am working on implementing a GAN.

For computing the loss of the generator, I compute both the negative probabilities that the discriminator mis-classifies an all-real minibatch and an all-generated-fake minibatch. Then, I back-propagate both parts sequentially and finally apply the step function.

Calculating and back-propagating the part of the loss that depends on the misclassification of the generated fake data seems straightforward, since during back-propagation of that loss term the backward path leads through the generator, which produced the fake data in the first place.

However, classifying all-real minibatches does not involve passing data through the generator. Therefore, I was wondering whether the following code snippet would still calculate gradients for the generator, or whether it would not calculate any gradients at all (since the backward path does not lead through the generator and the discriminator is in eval mode while updating the generator)?

# Update generator #
net.generator.train()
net.discriminator.eval()
net.generator.zero_grad()

# All-real minibatch
x_real = get_all_real_minibatch()
y_true = torch.full((batch_size,), label_fake).long()  # Pretend true targets were fake
y_pred = net.discriminator(x_real)  # Produces softmax probability distribution over (0=label_fake,1=label_real)
            
loss_real = torch.nn.functional.nll_loss(torch.log(y_pred), y_true)
loss_real.backward()
optimizer_generator.step()

If this doesn’t work as intended, how could I make it work? Thanks in advance!

Assuming get_all_real_minibatch() loads real samples without using the generator, the generator won’t get any gradients, since it isn’t part of the current computation.

The generator only gets gradients when a “fake input” (created by the generator) is passed to the discriminator and the loss is computed against “real targets”, so that the generator can be updated to learn how to fool the discriminator.
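On the eval-mode part of the question: as far as I know, calling eval() only changes the behaviour of layers such as dropout and batch norm and does not switch off gradient computation. Here is a minimal sketch of that, with a small made-up discriminator (the layer sizes and the input x are placeholders, not the model from the question). If the discriminator should stay untouched during the generator update, freezing its parameters with requires_grad_(False) is one option.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the discriminator, not the model from the question.
discriminator = nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.5), nn.Linear(8, 2))

discriminator.eval()  # only changes the behaviour of layers like Dropout/BatchNorm
x_real = torch.randn(3, 4)
discriminator(x_real).sum().backward()

# Gradients are still accumulated in the discriminator while in eval mode.
grads_in_eval_mode = [p.grad is not None for p in discriminator.parameters()]
print(grads_in_eval_mode)

# Freeze the discriminator: no parameter gradients are computed any more,
# but gradients still flow through it to its input.
discriminator.zero_grad(set_to_none=True)
for p in discriminator.parameters():
    p.requires_grad_(False)

x = torch.randn(3, 4, requires_grad=True)  # stands in for a generator output
discriminator(x).sum().backward()
frozen_grads = [p.grad for p in discriminator.parameters()]
print(frozen_grads)
print(x.grad is not None)
```

Since optimizer_generator.step() only updates the generator’s parameters, leftover gradients in the discriminator should be harmless as long as they are zeroed before the next discriminator update.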

Could you explain your use case a bit?
Currently the posted code doesn’t include the generator, so I’m not sure how gradients should be calculated and what they should represent.

I am working on implementing a variant of the architecture presented in this research paper. So, the task is to learn translations of word embeddings from a source language (word embedding space) to a target language (word embedding space). For that, the authors used a GAN, where the generator implemented the translation of source embeddings and the discriminator had to judge which embeddings were native to the target space and which were the fake translations.
The authors of the paper indicate in equation 4 that the generator is trained on both the probability of the discriminator misclassifying the real data and the fake data.

I had a look at the source code associated with the aforementioned research paper and saw that the authors combine both loss terms by passing the discriminator’s predictions on both real and fake data together to the loss function when computing the generator’s loss. I am assuming that this ensures that the loss term resulting from the misclassification of real data is taken into account when computing the generator’s gradients.

You are right in assuming that in the example code I posted above, the get_all_real_minibatch() method is supposed to sample a minibatch of unmodified target language word embeddings, and hence unmodified data the discriminator is supposed to accept as (real) target data (which has of course not been modified by the generator in any way).

The generator itself (omitting unnecessary further complications that are not relevant to the following) consists of an encoder and a decoder, which in the simplest form boils down to a concatenation of dense layers w_in through w_out. So, then the modified code snippet including the generator would become the following:

# Update generator #
net.generator.train()
net.discriminator.eval()
net.generator.zero_grad()

# All-real minibatch
x_real = get_all_real_minibatch()
y_true = torch.full((batch_size,), label_fake).long()  # Pretend true targets were fake
y_pred = net.discriminator(x_real)  # Produces softmax probability distribution over (0=label_fake,1=label_real)
loss_real = torch.nn.functional.nll_loss(torch.log(y_pred), y_true)

# All-fake minibatch 
x_src = get_source_language_embedding_minibatch()
x_fake = net.generator(x_src)  # Translate source embeddings into target space
y_true = torch.full((batch_size,), label_real).long()  # Pretend fake data was real target data
y_pred = net.discriminator(x_fake)

loss = torch.nn.functional.nll_loss(torch.log(y_pred), y_true) + loss_real
loss.backward()
optimizer_generator.step()

In this updated version of the code snippet, is loss_real actually taken into consideration when computing the gradients of the generator (since it is now added to the loss_fake term that is backpropagated through the generator)? Or are further adjustments necessary?

If I understand the code correctly, get_dis_xy will return two tensors, which were not created by the generator (or if they were, then they are detached in these lines of code, since they are re-wrapped in the deprecated Variable class).
You could check for gradients in the generator after the backward() op was called in this line and should see no updates in the generator’s gradients.
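A check like that could look roughly as follows. This is only a sketch: grad_report is a helper name I made up, and the two linear layers are placeholders for the real generator and discriminator.

```python
import torch
import torch.nn as nn


def grad_report(module):
    """Map each parameter name to True if it holds a non-zero gradient."""
    return {
        name: p.grad is not None and bool(p.grad.abs().sum() > 0)
        for name, p in module.named_parameters()
    }


torch.manual_seed(0)
generator = nn.Linear(3, 3)      # hypothetical stand-in
discriminator = nn.Linear(3, 2)  # hypothetical stand-in

# Loss computed on real data only; the generator is never called.
x_real = torch.randn(5, 3)
y_true = torch.zeros(5, dtype=torch.long)
loss_real = nn.functional.cross_entropy(discriminator(x_real), y_true)
loss_real.backward()

report_generator = grad_report(generator)
report_discriminator = grad_report(discriminator)
print(report_generator)
print(report_discriminator)
```

Calling the helper on the generator right after backward() shows which of its parameters, if any, received a gradient.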

You are correct that both losses are added. However, Autograd follows the backpropagation rules during the backward call, so the generator still won’t get any gradient from the real-data loss term if it wasn’t used in the creation of that term.
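Written out for a generic two-term generator loss, the reason becomes visible directly. The notation is my own and is not copied from equation 4 of the paper: $G_\theta$ is the generator with parameters $\theta$, $D_\phi$ the discriminator, $x_i$ the source samples and $y_j$ the real target samples.

```latex
\begin{aligned}
L_G(\theta)
  &= -\frac{1}{n}\sum_{i=1}^{n} \log D_\phi\bigl(\text{real} \mid G_\theta(x_i)\bigr)
     \;-\; \frac{1}{m}\sum_{j=1}^{m} \log D_\phi\bigl(\text{fake} \mid y_j\bigr), \\
\nabla_\theta L_G(\theta)
  &= -\frac{1}{n}\sum_{i=1}^{n} \nabla_\theta \log D_\phi\bigl(\text{real} \mid G_\theta(x_i)\bigr)
     \;-\; \underbrace{\frac{1}{m}\sum_{j=1}^{m} \nabla_\theta \log D_\phi\bigl(\text{fake} \mid y_j\bigr)}_{=\,0}.
\end{aligned}
```

The second term contains no $\theta$, so it changes the value of the loss but contributes nothing to the generator’s gradient.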

Here is a very simple example, which mimics your workflow.
In the first part I’ll accumulate the losses and call loss.backward(), while in the second use case I call backward on each loss separately.
As you can see, the loss_real.backward() call does not create any gradients in the generator.
The final gradients are equal to those from the first approach, which means that the generator gets its gradient from the loss_fake.backward() call only.

import torch
import torch.nn as nn

# Accumulate losses
torch.manual_seed(2809)

generator = nn.Linear(1, 1, bias=False)
discriminator = nn.Linear(1, 2, bias=False)

criterion = nn.CrossEntropyLoss()

x_real = torch.randn(1, 1)
y_pred_real = discriminator(x_real)
loss_real = criterion(y_pred_real, torch.zeros(y_pred_real.size(0)).long())

x_src = torch.randn(1, 1)
x_fake = generator(x_src)
y_pred_fake = discriminator(x_fake)
loss_fake = criterion(y_pred_fake, torch.ones(y_pred_fake.size(0)).long())

loss = loss_real + loss_fake
loss.backward()

print(generator.weight.grad)
print(discriminator.weight.grad)

# Separate calls
generator.zero_grad()
discriminator.zero_grad()

y_pred_real = discriminator(x_real)
loss_real = criterion(y_pred_real, torch.zeros(y_pred_real.size(0)).long())
loss_real.backward()

print(generator.weight.grad)
print(discriminator.weight.grad)

x_fake = generator(x_src)
y_pred_fake = discriminator(x_fake)
loss_fake = criterion(y_pred_fake, torch.ones(y_pred_fake.size(0)).long())
loss_fake.backward()

print(generator.weight.grad)
print(discriminator.weight.grad)

I think the first half of x contains the transformations of source embeddings produced by the generator. As I understand it, the original source embeddings are retrieved from the source embeddings-set in line 69 and the translation of these original source embeddings into the target space is done by the generator self.mapping in line 71. Then, the translated source embeddings and the target embeddings get concatenated in line 75 and jointly returned in tensor x in line 81 for classification of both generator-generated translations and target data by the discriminator in line 117. In the next line, the loss is then computed on both the generator-generated translations and the pure target data.
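To convince myself how the joint batch behaves, I put together the sketch below. It is not the code from the paper’s repository: mapping, the embedding tensors, the label assignment and the use of BCELoss are placeholders of my own. It compares the generator gradient from a loss over the concatenated batch with the gradient from the translated half alone.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, n = 4, 6

mapping = nn.Linear(dim, dim, bias=False)                       # stand-in generator
discriminator = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())  # stand-in discriminator
criterion = nn.BCELoss()  # mean reduction over the batch

src_emb = torch.randn(n, dim)  # placeholder source embeddings
tgt_emb = torch.randn(n, dim)  # placeholder target embeddings

# Joint batch: first half translated source embeddings, second half target embeddings.
x = torch.cat([mapping(src_emb), tgt_emb], dim=0)
y = torch.cat([torch.ones(n), torch.zeros(n)])  # illustrative label assignment
loss_joint = criterion(discriminator(x).squeeze(1), y)
loss_joint.backward()
grad_joint = mapping.weight.grad.clone()

# Same computation restricted to the translated half.
mapping.weight.grad = None
loss_fake = criterion(discriminator(mapping(src_emb)).squeeze(1), torch.ones(n))
loss_fake.backward()
grad_fake = mapping.weight.grad.clone()

# The joint loss averages over 2n samples, so the generator gradient is half of
# the fake-only gradient; the target half contributes nothing to it.
same_direction = torch.allclose(grad_joint, 0.5 * grad_fake, atol=1e-6)
print(same_direction)
```

So in this toy setup the target embeddings in the joint batch only rescale the generator’s gradient through the averaging and do not add any information to it.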

But I understand now that gradients are only computed for a module with respect to a loss if that module took part in computing the loss. Thank you very much for the clarification!