Essentially, we have an encoder network (it takes an input belonging to some class and outputs latents) and a classifier that tries to predict the class of the input based solely on its latents. The minimax game here is for the encoder to find class-independent latents and for the classifier to classify them correctly.

The original implementation avoids using retain_graph by first computing the latents and optimizing the classifier with a classification loss, and then recomputing the latents (with exactly the same input if the data augmentation option is disabled) and optimizing the encoder with the negative classification loss (the same loss function but recomputed to use the updated parameters of the classifier).
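As a rough illustration of the recompute strategy described above, here is a minimal sketch. The toy `nn.Linear` modules, sizes, learning rates, and variable names are all assumptions standing in for the real encoder and classifier; this is not the original implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-ins (hypothetical sizes) for the encoder and the classifier.
encoder = nn.Linear(4, 3)
classifier = nn.Linear(3, 2)
enc_opt = torch.optim.SGD(encoder.parameters(), lr=0.1)
clf_opt = torch.optim.SGD(classifier.parameters(), lr=0.1)

x = torch.randn(8, 4)
labels = torch.randint(0, 2, (8,))

# Snapshots, only used to check that both networks were updated.
enc_before = encoder.weight.detach().clone()
clf_before = classifier.weight.detach().clone()

# Step 1: first forward pass, update the classifier only.
clf_loss = F.cross_entropy(classifier(encoder(x)), labels)
clf_opt.zero_grad()
clf_loss.backward()  # the graph is freed here, no retain_graph
clf_opt.step()

# Step 2: recompute the latents and update the encoder with the
# negated loss, now evaluated with the updated classifier.
enc_loss = -F.cross_entropy(classifier(encoder(x)), labels)
enc_opt.zero_grad()
enc_loss.backward()
enc_opt.step()
```

The cost of this variant is the second `encoder(x)` forward pass; in exchange, no graph has to be kept alive across the classifier's optimizer step.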

My question is why do we have to use retain_graph if the optimizers are on disjoint sets of parameters? And, is using retain_graph the same as recomputing the latents in this case?

You need to use retain_graph because .backward() goes through the whole graph (both the encoder and the decoder here). So if you want to be able to backward through the decoder again, you need retain_graph.

You can use retain_graph as long as you don’t change any value required by the backward pass. In particular here, the optimizer's step() changes the parameters in place and might prevent you from being able to backward a second time (make sure to run v1.5.0+, as this was fixed recently).
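The in-place issue can be reproduced with a small sketch. The toy modules and names below are assumptions for illustration only; on v1.5.0+ the second backward is expected to raise a RuntimeError, because the retained graph still refers to the classifier weights that step() has since modified.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-ins (hypothetical sizes) for the encoder and the classifier.
encoder = nn.Linear(4, 3)
classifier = nn.Linear(3, 2)
clf_opt = torch.optim.SGD(classifier.parameters(), lr=0.1)

x = torch.randn(8, 4)
labels = torch.randint(0, 2, (8,))

loss = F.cross_entropy(classifier(encoder(x)), labels)
clf_opt.zero_grad()
loss.backward(retain_graph=True)  # keep the graph for a second backward
clf_opt.step()                    # modifies classifier.weight in place

second_backward_failed = False
try:
    # The retained graph saved classifier.weight to compute the gradient
    # with respect to the latents; that tensor has changed since.
    (-loss).backward()
except RuntimeError as err:
    second_backward_failed = True
    print("second backward failed:", type(err).__name__)
```

Running both backward passes before any step(), or redoing the forward pass after the step, avoids touching a stale graph.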

Both approaches behave very similarly. You either do extra work during a backward pass that you don’t care about, or an extra forward pass.

Regarding your answer, .backward() doesn’t go through the decoder when optimizing the classifier, since the decoder's weights are not used at any point in that computation graph.

My confusion was more about having to explicitly detach the input variable even if it’s not used in the optimizer. To me it’s more natural to compute the classifier loss on the detached latents and then, without recomputing the latents, compute the same loss again to optimize the encoder.
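To check that idea on something small, here is a toy sketch (hypothetical `nn.Linear` modules, sizes, and names, not the real networks): the classifier update runs on detached latents, so nothing flows back into the encoder, and the same latents are then reused for the encoder update without retain_graph.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-ins (hypothetical sizes) for the encoder and the classifier.
encoder = nn.Linear(4, 3)
discriminator = nn.Linear(3, 2)
g_opt = torch.optim.SGD(encoder.parameters(), lr=0.1)
d_opt = torch.optim.SGD(discriminator.parameters(), lr=0.1)

x = torch.randn(8, 4)
labels = torch.randint(0, 2, (8,))

z = encoder(x)  # latents are computed once

# Classifier update on detached latents: backward stops at z.detach().
d_loss = F.cross_entropy(discriminator(z.detach()), labels)
d_opt.zero_grad()
d_loss.backward()
encoder_untouched = encoder.weight.grad is None  # no gradient reached the encoder
d_opt.step()

# Encoder update: a fresh classifier forward on the same latents.
# The encoder part of the graph was never traversed, so it is still intact.
g_loss = -F.cross_entropy(discriminator(z), labels)
g_opt.zero_grad()
g_loss.backward()
g_opt.step()
```

The second classifier forward builds a new graph from the updated weights, so the in-place update from `d_opt.step()` does not invalidate anything that the encoder's backward needs.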

The simplified code of the music translation network then looks something like this, and you don’t need to recompute the latents.

# Class labels for this batch: every sample comes from dataset `dset_num`
targets = torch.full((x.size(0),), dset_num, dtype=torch.long, device=x.device)
# Optimize D - discriminator right
z = self.encoder(x)
# detach() keeps this backward from reaching the encoder
z_logits = self.discriminator(z.detach())
discriminator_right = F.cross_entropy(z_logits, targets)
self.d_optimizer.zero_grad()
discriminator_right.backward()
self.d_optimizer.step()
# optimize G - reconstructs well, discriminator wrong
y = self.decoder(x, z)
# fresh discriminator forward on the same latents, using the updated weights
z_logits = self.discriminator(z)
discriminator_wrong = - F.cross_entropy(z_logits, targets)