VAEGAN - Multiple losses and multiple networks training

Hello,
I am at a loss when implementing the training part of a VAE+GAN, specifically because the different parts of the model (encoder, decoder, discriminator) are trained with different losses. The base paper I am trying to implement is here: https://arxiv.org/pdf/1512.09300. I tried several training procedures:

  • using a single backward pass on the sum of all losses, but I am not sure the gradients are propagated correctly, since the paper states that the different parts of the network should be trained with different combinations of losses
  • using multiple backward passes with retain_graph=True, despite knowing that this is frowned upon except for very specific cases. This also throws errors due to in-place operations (see code below).
  • alternating between zeroing gradients and backward passes for each network within the full VAE+GAN, but this throws errors because the computation graph has already been freed by the previous backward passes.

As much as possible, I would like to avoid having to do multiple forward passes. I tried detaching tensors from the computation graph, but to no avail.

Here is an example with the MNIST dataset and the multiple-backward-passes case.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader

class Encoder(nn.Module):
    def __init__(self, in_channels, z_dim):
        super(Encoder, self).__init__()
        self.layers = 3 
        kernel_size = 5
        stride = 2
        padding = 1
        self.conv = []
        current_in = in_channels
        current_out = 16
        for i in range(self.layers):
            if i == self.layers - 1:
                self.conv.append(
                    nn.Conv2d(current_in, current_out, kernel_size, stride, padding, bias=False)
                )
            else:
                self.conv.append(
                    nn.Conv2d(current_in, current_out, kernel_size, stride, padding, bias=False)
                )
                self.conv.append(nn.BatchNorm2d(current_out))
                self.conv.append(nn.LeakyReLU(0.2))
                self.conv.append(nn.Dropout(0.3))

            current_in = current_out
            current_out *= 2
        current_out = current_out // 2
        self.conv = nn.Sequential(*self.conv)
        self.var = nn.Linear(current_out * 2 * 2, z_dim)
        self.mu = nn.Linear(current_out * 2 * 2, z_dim)
        self.flatten = nn.Flatten()

    def forward(self, x):
        x = self.conv(x)
        # mu = self.mu(x.view)
        mu = self.mu(self.flatten(x))
        var = self.var(self.flatten(x))
        return mu, var

class Decoder(nn.Module):
    def __init__(self, z_dim, target_shape):
        super(Decoder, self).__init__()
        self.base_channels = 256
        self.base_feature_maps = 4
        self.layers = 3
        self.fc = nn.Linear(z_dim, self.base_channels * self.base_feature_maps * self.base_feature_maps)
        self.unflatten = nn.Unflatten(
            1, 
            (self.base_channels, self.base_feature_maps, self.base_feature_maps))
        self.conv = []
        current_in = self.base_channels
        kernel_size = 4
        stride = 2
        padding = 1
        for i in range(self.layers):
            if i == self.layers - 1:
                current_out = 1
                self.conv.append(
                    nn.ConvTranspose2d(current_in, current_out, kernel_size, stride, padding, bias=False),
                )
                self.conv.append(nn.Upsample(target_shape))
                self.conv.append(nn.Sigmoid())
            else:
                current_out = max(current_in // 2, 1)
                self.conv.append(
                    nn.ConvTranspose2d(current_in, current_out, kernel_size, stride, padding, bias=False),
                )
                self.conv.append(nn.BatchNorm2d(current_out))
                self.conv.append(nn.LeakyReLU(0.2))
            current_in = current_out
        self.conv = nn.Sequential(*self.conv)

    def forward(self, z):
        x = self.fc(z)
        x = self.unflatten(x)
        x = self.conv(x)
        return x

class Discriminator(nn.Module):
    def __init__(self, in_channels):
        super(Discriminator, self).__init__()
        self.layers = 3 
        kernel_size = 5
        stride = 2
        padding = 1
        self.conv = []
        current_in = in_channels
        current_out = 16
        for i in range(self.layers):
            if i == self.layers - 1:
                self.conv.append(
                    nn.Conv2d(current_in, current_out, kernel_size, stride, padding, bias=False)
                )
            else:
                self.conv.append(
                    nn.Conv2d(current_in, current_out, kernel_size, stride, padding, bias=False)
                )
                self.conv.append(nn.BatchNorm2d(current_out))
                self.conv.append(nn.LeakyReLU(0.2))
                self.conv.append(nn.Dropout(0.3))

            current_in = current_out
            current_out *= 2
        current_out = current_out // 2
        self.conv = nn.Sequential(*self.conv)
        self.fc = nn.Linear(current_out * 2 * 2, 1)
        self.aux = nn.Linear(current_out * 2 * 2, 64)
        self.flatten = nn.Flatten()
        self.activation = nn.Sigmoid()
        
    def forward(self, x):
        x = self.conv(x)
        x_aux = self.aux(self.flatten(x))
        x = self.activation(self.fc(self.flatten(x)))
        return x, x_aux

class VAEGAN(nn.Module):
    def __init__(self, input_shape, z_dim):
        super(VAEGAN, self).__init__()
        c, h, w = input_shape
        self.encoder = Encoder(c, z_dim)
        self.decoder = Decoder(z_dim, (h, w))
        self.discriminator = Discriminator(c)

    def forward(self, x):
        mu, log_var = self.encoder(x)
        eps = torch.randn_like(mu)
        std = torch.exp(0.5 * log_var)
        z = mu + eps * std # z=Enc(x)
        x_rec = self.decoder(z) # rec=Dec(z)
        z_p = torch.randn_like(z) 
        x_p = self.decoder(z_p) # x_p=Dec(z_p)
        dis_x, dis_x_aux = self.discriminator(x)
        dis_x_rec, dis_x_rec_aux = self.discriminator(x_rec)
        dis_x_p, dis_x_p_aux = self.discriminator(x_p)
        return dis_x, dis_x_rec, dis_x_p, dis_x_aux, dis_x_rec_aux, dis_x_p_aux, mu, log_var

img_size = 28 
batch_size = 32

dataloader = DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(img_size), 
                transforms.ToTensor(), 
            ]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
)

z_dim = 16
vaegan = VAEGAN((1, img_size, img_size), z_dim)
opt_enc = optim.Adam(vaegan.encoder.parameters(), lr=3e-4)
opt_dec = optim.Adam(vaegan.decoder.parameters(), lr=3e-4)
opt_dis = optim.Adam(vaegan.discriminator.parameters(), lr=3e-4)

epochs = 100
gamma = 0.3

for epoch in range(epochs):
    train_loss = 0
    for i, (imgs, targets) in enumerate(dataloader):
        with torch.autograd.detect_anomaly(): # to follow the trace back to the problematic tensor
            # soft labels
            fake = torch.full_like(targets, 0.1).unsqueeze(1).float()
            real = torch.full_like(targets, 0.9).unsqueeze(1).float()
            mu, log_var = vaegan.encoder(imgs)
            eps = torch.randn_like(mu)
            std = torch.exp(0.5 * log_var)
            z = mu + eps * std
            x_rec = vaegan.decoder(z)
            z_p = torch.randn_like(z)
            x_p = vaegan.decoder(z_p)
            dis_x, _ = vaegan.discriminator(imgs)
            dis_x_rec, _ = vaegan.discriminator(x_rec.detach())
            dis_x_p, _ = vaegan.discriminator(x_p.detach()) 

            # 1. Train Discriminator
            opt_dis.zero_grad()
            gan_loss = F.binary_cross_entropy(dis_x, real) + F.binary_cross_entropy(dis_x_rec, fake) + F.binary_cross_entropy(dis_x_p, fake)
            gan_loss.backward()
            nn.utils.clip_grad_norm_(vaegan.discriminator.parameters(), 5)
            opt_dis.step()

            # 2. Get discriminator features for generator training (requires fresh forward pass)
            dis_x, dis_x_aux = vaegan.discriminator(imgs)
            dis_x_rec, dis_x_rec_aux = vaegan.discriminator(x_rec)
            
            # 3. Train Encoder
            opt_enc.zero_grad()
            kld = -0.5 * torch.sum(1 + log_var - torch.pow(mu, 2) - torch.exp(log_var))
            mse = F.mse_loss(dis_x_aux, dis_x_rec_aux)
            enc_loss = kld + mse
            enc_loss.backward(retain_graph=True)  # Need to retain for decoder training
            nn.utils.clip_grad_norm_(vaegan.encoder.parameters(), 5)
            opt_enc.step()
            
            # 4. Train Decoder/Generator
            opt_dec.zero_grad()
            gen_loss = F.binary_cross_entropy(dis_x_rec, real)
            dec_loss = gamma * mse + gen_loss
            dec_loss.backward() # this raises the tensor version error
            nn.utils.clip_grad_norm_(vaegan.decoder.parameters(), 5)
            opt_dec.step()
            
            print(f"Epoch {epoch} Batch {100*i/len(dataloader):.2f}% Enc loss {enc_loss.item():.3f} Dec loss {dec_loss:.3f} Dis loss {gan_loss.item():.3f}\n", end="")

The traceback leads to the encoder and says I have an in-place operation. Could it be related to the Flatten? Or to me using “stale” data when training after the second pass?
Thanks in advance!

Edit: sorry for the very long code block, but I wanted to include every possible source of error. I tried this training procedure with simpler networks (1D data) and got no errors, so I think the problem may come from the networks rather than the training procedure.

While continuing to investigate, I found that replacing the forward of the encoder with something silly like:

def forward(self, x):
    x = self.conv(x)
    r_data = torch.randn_like(self.flatten(x))
    # mu = self.mu(self.flatten(x))
    # var = self.var(self.flatten(x))
    mu = self.mu(r_data)
    var = self.var(r_data)
    return mu, var

(which effectively bypasses the flatten while keeping the dimensions) makes the error disappear.
I do not really understand how the flatten layer could act in-place, however…

Hi Martin!

I haven’t looked at your code in detail, but the likely cause is as follows:

opt_enc.step() performs inplace modifications on the model parameters it is optimizing.

But you then perform a second .backward() on the graph that you retained and modified
inplace, hence the error.
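
Here is a minimal, self-contained sketch of that failure mode (a toy Linear layer, not your
model): step() modifies the saved weight in place, so the second backward over the retained
graph hits the inplace-modification error.

import torch
import torch.nn as nn
import torch.optim as optim

lin = nn.Linear(4, 4)
opt = optim.SGD(lin.parameters(), lr=0.1)

x = torch.randn(2, 4, requires_grad=True)
out = lin(x)                       # lin.weight is saved for the backward pass

loss1 = out.sum()
loss1.backward(retain_graph=True)  # first backward; graph (and saved weight) kept alive
opt.step()                         # in-place update of lin.weight

loss2 = out.pow(2).sum()
loss2.backward()                   # RuntimeError: one of the variables needed for gradient
                                   # computation has been modified by an inplace operation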

It appears that opt_dec.step() does not optimize any of the encoder parameters. If this
is the case, then the encoder portion of the graph – which is what you backward through a
second time in dec_loss.backward() – need not be part of the decoder’s graph at all.

You are already computing dis_x_rec twice – the first time detached from the encoder
graph. Would it be possible to keep the detached version (instead of discarding it) and
use it to compute gen_loss and hence dec_loss? Then dec_loss.backward() won’t
re-traverse the inplace-modified encoder graph.

Of course, you should do some debugging to make sure that this is actually what is happening
(and that there aren’t similar errors that would have shown up after the first error is raised).

You can find a discussion of how to debug and fix inplace-modification errors in this post.

I always recommend understanding what is going on and addressing it directly, but pytorch
does offer a sweep-inplace-modification-errors-under-the-rug context manager.

In reference to your second post:

torch.randn_like(self.flatten(x)) creates a new tensor that doesn’t depend on x (and doesn’t carry requires_grad = True).

So when you backpropagate through the encoder, you backpropagate through self.mu
and self.var, but not through self.conv. If just parameters in self.conv are the ones that
cause problems when modified inplace, you’ve avoided the problem by not backpropagating
through them.
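
As a tiny self-contained illustration of that point:

import torch

x = torch.randn(3, requires_grad=True)
y = torch.randn_like(x)   # fresh tensor: same shape and dtype as x, but no autograd history
print(y.requires_grad)    # False – nothing will backpropagate into x through y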

Best.

K. Frank

Hi @KFrank,
I implemented some of the changes you proposed and ended up with this training loop:

      # soft labels
      fake = torch.full_like(targets, 0.1).unsqueeze(1).float()
      real = torch.full_like(targets, 0.9).unsqueeze(1).float()
      mu, log_var = vaegan.encoder(imgs)
      eps = torch.randn_like(mu)
      std = torch.exp(0.5 * log_var)
      z = mu + eps * std
      # detach z 
      z_detached = z.detach()
      x_rec = vaegan.decoder(z)
      x_rec_detached = x_rec.detach() # for use in gan loss
      # ask for a fresh x_rec from the detached z for use in gen loss
      x_rec_detach_gen_only = vaegan.decoder(z_detached)
      z_p = torch.randn_like(z)
      x_p = vaegan.decoder(z_p)
      # run the forwards of the discriminator
      dis_x, dis_x_aux = vaegan.discriminator(imgs) # > for gan and mse
      dis_x_rec, _ = vaegan.discriminator(x_rec_detached) # > for gan
      _, dis_x_rec_aux = vaegan.discriminator(x_rec) # > for mse
      dis_x_rec_gen_only, _ = vaegan.discriminator(x_rec_detach_gen_only) # > for gen
      dis_x_p, _ = vaegan.discriminator(x_p) # > for gan
      # compute losses
      gan_loss = F.binary_cross_entropy(dis_x, real) + F.binary_cross_entropy(dis_x_rec, fake) + F.binary_cross_entropy(dis_x_p, fake)
      kld = -0.5 * torch.sum(1 + log_var - torch.pow(mu, 2) - torch.exp(log_var))
      mse = F.mse_loss(dis_x_aux, dis_x_rec_aux)
      enc_loss = kld + mse
      gen_loss = F.binary_cross_entropy(dis_x_rec_gen_only, real)
      dec_loss = gamma * mse + gen_loss
      # zero grads and train
      opt_enc.zero_grad()
      opt_dec.zero_grad()
      opt_dis.zero_grad()
      enc_loss.backward(retain_graph=True)
      gen_loss.backward()
      gan_loss.backward()
      opt_enc.step()
      opt_dec.step()
      opt_dis.step()

which works fine, thank you!
I took some time to create a flow diagram of what was expected below. Dotted tensors are leaves and colors show the gradient flows.


In the end, the only things I do not like are the multiple forward passes on the discriminator and the retain_graph=True remaining, but their use seems justified here.
Let me know if you can think of a cleaner or more principled solution.
Best regards,
Martin

Hi Martin!

I wouldn’t worry about retain_graph = True – it has legitimate use cases and looks
reasonable here. (The problem is that it is sometimes used to hide a substantive underlying
issue, but it’s often appropriate with GANs.)

I’m not commenting on your architecture in particular – just making some general comments
about GAN training:

When training GANs, duplicating the forward pass is often the right – or at least the most
practical – thing to do. But there are a couple of ways to avoid it.

Let’s say you have a generator feeding a discriminator and use something like
BCEWithLogitsLoss as your loss criterion. When the discriminator gets the right answer
(so the discriminator is doing well, but the generator is doing poorly), the loss is low.
You can backpropagate this loss to train the discriminator. But if you insert a custom
layer between the generator and the discriminator that flips the sign of the gradient
during backpropagation, you can then use the same forward / backward pass to also
train the generator.

(Note, although I think this approach is valid, I think that it’s not optimal. This is because
flipping the sign of the gradient is not mathematically the same as flipping the label used
in BCEWithLogitsLoss, so to speak, from “fake” to “real”. I don’t have evidence for this,
but I would expect that backpropagating the flipped label – at the cost of running a second
backward pass – will train the generator more effectively than backpropagating the flipped
sign – which would only require a single backward pass.)
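
For concreteness, here is a minimal sketch of such a sign-flipping (gradient-reversal) layer
as a custom autograd Function; the generator / discriminator / criterion names in the usage
comment are placeholders, not your model:

import torch

class GradReverse(torch.autograd.Function):
    # Identity in the forward pass; flips the sign of the gradient in the backward pass.
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg()

# Hypothetical usage inside a training step:
#   fake = generator(z)
#   d_out = discriminator(GradReverse.apply(fake))
#   loss = criterion(d_out, fake_labels)  # the discriminator's usual loss
#   loss.backward()                       # discriminator gets the ordinary gradient;
#                                         # the generator receives the negated gradient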

The other thing you can do is perform a forward pass for the generator, make the output
of the generator a leaf variable, and then perform the forward pass for the discriminator
using that leaf variable as its input. You can then call .backward() for the discriminator
and the generator separately and have more granular control over the two “half” backward
passes, for example how (and whether) they populate the various .grads.
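
A minimal, self-contained sketch of this second option (the tiny Linear stand-ins and names
are made up for illustration):

import torch
import torch.nn as nn

generator = nn.Linear(8, 4)        # toy stand-ins for the real networks
discriminator = nn.Linear(4, 1)
criterion = nn.BCEWithLogitsLoss()

z = torch.randn(16, 8)
fake = generator(z)                            # generator forward pass

fake_leaf = fake.detach().requires_grad_(True) # make the generator output a leaf

d_out = discriminator(fake_leaf)               # discriminator forward pass on the leaf
g_loss = criterion(d_out, torch.ones(16, 1))   # "real" labels for the generator objective

g_loss.backward()                  # first "half": populates discriminator grads and
                                   # fake_leaf.grad, stopping at the leaf
fake.backward(fake_leaf.grad)      # second "half": push the gradient collected at the
                                   # leaf back into the generator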

Best.

K. Frank

Thanks @KFrank.
Just to be sure I understand: making the output of the generator a leaf variable is
essentially what I did by detaching tensors and using a separate optimizer for each part of
the whole VAEGAN, or is there another way to do it?
Also:

You could, but that would probably be best kept for another topic on data generation :smile:
Anyway thanks again,
Martin