VAEGAN - Multiple losses and multiple networks training

Hello,
I am at a loss when implementing the training part of a VAE+GAN. This is specifically because the full VAE+GAN uses different losses to train different parts of the VAE+GAN. The base paper I am trying to implement is here: https://arxiv.org/pdf/1512.09300. I tried several training procedures:

  • using a single backward pass by adding all losses but I am not sure the gradients are propagated correctly as the paper states that the different parts of the network should be trained with different combinations of losses
  • using multiple backward passes but with retain_graph=True despite knowing that this is frowned upon except for very specific cases. This also throws errors due to inplace operations (see code below).
  • alternating between zeroing gradients and backward passes for each network within the full VAE+GAN, but this throws propagation errors as the gradients are already freed by previous passes.
    As much as possible, I would like to avoid having to do multiple forward passes. I tried detaching tensors from the computation graphs but to no avail.

Here is an example with an MNIST dataset and the multiple backward passes case.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader

class Encoder(nn.Module):
    """Convolutional encoder with two linear heads over the flattened features.

    The trunk is three 5x5, stride-2, padding-1 convolutions without bias
    (16, 32 and 64 output channels). Every convolution except the last one is
    followed by BatchNorm2d, LeakyReLU(0.2) and Dropout(0.3). Both heads take
    ``64 * 2 * 2`` input features, i.e. they expect the trunk to end in a 2x2
    feature map (which is what a 28x28 input produces with these settings).

    Args:
        in_channels: Number of channels of the input images.
        z_dim: Output size of each of the two heads.
    """

    def __init__(self, in_channels, z_dim):
        super().__init__()
        self.layers = 3
        conv_kwargs = dict(kernel_size=5, stride=2, padding=1, bias=False)

        blocks = []
        channels_in, channels_out = in_channels, 16
        for depth in range(self.layers):
            blocks.append(nn.Conv2d(channels_in, channels_out, **conv_kwargs))
            # The last convolution feeds the linear heads directly.
            if depth < self.layers - 1:
                blocks.append(nn.BatchNorm2d(channels_out))
                blocks.append(nn.LeakyReLU(0.2))
                blocks.append(nn.Dropout(0.3))
            channels_in, channels_out = channels_out, channels_out * 2
        self.conv = nn.Sequential(*blocks)

        # channels_in now holds the width of the last convolution.
        flat_features = channels_in * 2 * 2
        self.var = nn.Linear(flat_features, z_dim)
        self.mu = nn.Linear(flat_features, z_dim)
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Return ``(mu, var)``, the outputs of the two heads for batch ``x``.

        The rest of this file consumes the second output as a log-variance.
        """
        features = self.flatten(self.conv(x))
        return self.mu(features), self.var(features)

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent vector to an image.

    A linear layer expands the latent vector to a 256x4x4 tensor, which three
    4x4, stride-2, padding-1 transposed convolutions (no bias) upsample. The
    first two are followed by BatchNorm2d and LeakyReLU(0.2) and halve the
    channel count; the last one maps to ``out_channels`` and is followed by
    ``nn.Upsample(target_shape)`` and a sigmoid, so outputs lie in [0, 1].

    Args:
        z_dim: Size of the latent vector.
        target_shape: ``(height, width)`` of the produced images.
        out_channels: Number of image channels to produce. Defaults to 1,
            the value that used to be hard-coded.
    """

    def __init__(self, z_dim, target_shape, out_channels=1):
        super().__init__()
        self.base_channels = 256
        self.base_feature_maps = 4
        self.layers = 3

        base_shape = (self.base_channels, self.base_feature_maps, self.base_feature_maps)
        self.fc = nn.Linear(z_dim, base_shape[0] * base_shape[1] * base_shape[2])
        self.unflatten = nn.Unflatten(1, base_shape)

        deconv_kwargs = dict(kernel_size=4, stride=2, padding=1, bias=False)
        blocks = []
        channels_in = self.base_channels
        for depth in range(self.layers):
            if depth == self.layers - 1:
                # Output stage: project to image channels, resize, squash to [0, 1].
                blocks.append(nn.ConvTranspose2d(channels_in, out_channels, **deconv_kwargs))
                blocks.append(nn.Upsample(target_shape))
                blocks.append(nn.Sigmoid())
            else:
                channels_out = max(channels_in // 2, 1)
                blocks.append(nn.ConvTranspose2d(channels_in, channels_out, **deconv_kwargs))
                blocks.append(nn.BatchNorm2d(channels_out))
                blocks.append(nn.LeakyReLU(0.2))
                channels_in = channels_out
        self.conv = nn.Sequential(*blocks)

    def forward(self, z):
        """Decode latent batch ``z`` of shape (N, z_dim) into images."""
        return self.conv(self.unflatten(self.fc(z)))

class Discriminator(nn.Module):
    """Convolutional discriminator with a probability head and a feature head.

    The trunk is three 5x5, stride-2, padding-1 convolutions without bias
    (16, 32 and 64 output channels); all but the last are followed by
    BatchNorm2d, LeakyReLU(0.2) and Dropout(0.3). Both heads take
    ``64 * 2 * 2`` flattened features, i.e. they expect a 2x2 final feature
    map (which is what a 28x28 input produces with these settings).

    Args:
        in_channels: Number of channels of the input images.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.layers = 3
        conv_kwargs = dict(kernel_size=5, stride=2, padding=1, bias=False)

        blocks = []
        channels_in, channels_out = in_channels, 16
        for depth in range(self.layers):
            blocks.append(nn.Conv2d(channels_in, channels_out, **conv_kwargs))
            # The last convolution feeds the linear heads directly.
            if depth < self.layers - 1:
                blocks.append(nn.BatchNorm2d(channels_out))
                blocks.append(nn.LeakyReLU(0.2))
                blocks.append(nn.Dropout(0.3))
            channels_in, channels_out = channels_out, channels_out * 2
        self.conv = nn.Sequential(*blocks)

        # channels_in now holds the width of the last convolution.
        flat_features = channels_in * 2 * 2
        self.fc = nn.Linear(flat_features, 1)
        self.aux = nn.Linear(flat_features, 64)
        self.flatten = nn.Flatten()
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Return ``(prediction, aux)`` for image batch ``x``.

        ``prediction`` is the sigmoid of the 1-unit head, shape (N, 1);
        ``aux`` is the raw output of the 64-unit head, shape (N, 64).
        """
        features = self.flatten(self.conv(x))
        aux_features = self.aux(features)
        prediction = self.activation(self.fc(features))
        return prediction, aux_features

class VAEGAN(nn.Module):
    """Bundle of an Encoder, a Decoder and a Discriminator.

    Args:
        input_shape: ``(channels, height, width)`` of the input images.
        z_dim: Size of the latent vector.
    """

    def __init__(self, input_shape, z_dim):
        super().__init__()
        channels, height, width = input_shape
        self.encoder = Encoder(channels, z_dim)
        self.decoder = Decoder(z_dim, (height, width))
        self.discriminator = Discriminator(channels)

    def forward(self, x):
        """Run the full pipeline on image batch ``x``.

        Returns the tuple ``(dis_x, dis_x_rec, dis_x_p, dis_x_aux,
        dis_x_rec_aux, dis_x_p_aux, mu, log_var)``: the discriminator
        predictions and aux features for the real images, for their
        reconstructions and for images decoded from standard-normal latents,
        followed by the two encoder outputs.
        """
        mu, log_var = self.encoder(x)

        # Reparameterisation: z = mu + eps * exp(0.5 * log_var), eps ~ N(0, I).
        noise = torch.randn_like(mu)
        sigma = torch.exp(0.5 * log_var)
        z = mu + noise * sigma

        reconstruction = self.decoder(z)
        prior_sample = self.decoder(torch.randn_like(z))

        pred_real, aux_real = self.discriminator(x)
        pred_rec, aux_rec = self.discriminator(reconstruction)
        pred_prior, aux_prior = self.discriminator(prior_sample)

        predictions = (pred_real, pred_rec, pred_prior)
        aux_features = (aux_real, aux_rec, aux_prior)
        return (*predictions, *aux_features, mu, log_var)

# Data: MNIST training split, resized to img_size and converted to tensors.
img_size, batch_size = 28, 32

mnist_transform = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.ToTensor(),
    ]
)
mnist_train = datasets.MNIST(
    "../../data/mnist",
    train=True,
    download=True,
    transform=mnist_transform,
)
dataloader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

# Model, plus one Adam optimizer per sub-network so each can be stepped separately.
z_dim = 16
vaegan = VAEGAN((1, img_size, img_size), z_dim)
learning_rate = 3e-4
opt_enc = optim.Adam(vaegan.encoder.parameters(), lr=learning_rate)
opt_dec = optim.Adam(vaegan.decoder.parameters(), lr=learning_rate)
opt_dis = optim.Adam(vaegan.discriminator.parameters(), lr=learning_rate)

# gamma weights the feature-matching MSE term inside the decoder loss.
epochs, gamma = 100, 0.3

for epoch in range(epochs):
    train_loss = 0  # not updated anywhere in this snippet
    for i, (imgs, targets) in enumerate(dataloader):
        # Anomaly detection makes autograd report the forward op behind a failing backward.
        with torch.autograd.detect_anomaly():
            # Soft labels. `targets` holds the integer MNIST class labels, so the float
            # dtype must be requested explicitly: a plain full_like(targets, 0.9) keeps the
            # integer dtype and truncates both 0.1 and 0.9 to 0.
            fake = torch.full_like(targets, 0.1, dtype=torch.float).unsqueeze(1)
            real = torch.full_like(targets, 0.9, dtype=torch.float).unsqueeze(1)

            # Single forward pass through encoder and decoder.
            mu, log_var = vaegan.encoder(imgs)
            eps = torch.randn_like(mu)
            std = torch.exp(0.5 * log_var)
            z = mu + eps * std
            x_rec = vaegan.decoder(z)
            z_p = torch.randn_like(z)
            x_p = vaegan.decoder(z_p)

            # Discriminator inputs are detached so gan_loss only reaches the discriminator.
            dis_x, _ = vaegan.discriminator(imgs)
            dis_x_rec, _ = vaegan.discriminator(x_rec.detach())
            dis_x_p, _ = vaegan.discriminator(x_p.detach())

            # 1. Train Discriminator
            opt_dis.zero_grad()
            gan_loss = F.binary_cross_entropy(dis_x, real) + F.binary_cross_entropy(dis_x_rec, fake) + F.binary_cross_entropy(dis_x_p, fake)
            gan_loss.backward()
            nn.utils.clip_grad_norm_(vaegan.discriminator.parameters(), 5)
            opt_dis.step()

            # 2. Fresh discriminator forward pass (its weights were just updated).
            #    x_rec is NOT detached here, so this graph reaches back through the
            #    decoder into the encoder.
            dis_x, dis_x_aux = vaegan.discriminator(imgs)
            dis_x_rec, dis_x_rec_aux = vaegan.discriminator(x_rec)

            # 3. Train Encoder
            opt_enc.zero_grad()
            kld = -0.5 * torch.sum(1 + log_var - torch.pow(mu, 2) - torch.exp(log_var))
            mse = F.mse_loss(dis_x_aux, dis_x_rec_aux)
            enc_loss = kld + mse
            enc_loss.backward(retain_graph=True)  # graph is reused by dec_loss below
            nn.utils.clip_grad_norm_(vaegan.encoder.parameters(), 5)
            opt_enc.step()  # updates the encoder weights in place

            # 4. Train Decoder/Generator
            opt_dec.zero_grad()
            gen_loss = F.binary_cross_entropy(dis_x_rec, real)
            dec_loss = gamma * mse + gen_loss
            # NOTE: as posted, this call raises the "modified by an inplace operation"
            # (tensor version) RuntimeError. dec_loss backpropagates through the retained
            # graph all the way into the encoder, whose weights opt_enc.step() has just
            # modified in place. The replies below discuss how to restructure this.
            dec_loss.backward()
            nn.utils.clip_grad_norm_(vaegan.decoder.parameters(), 5)
            opt_dec.step()

            print(f"Epoch {epoch} Batch {100*i/len(dataloader):.2f}% Enc loss {enc_loss.item():.3f} Dec loss {dec_loss.item():.3f} Dis loss {gan_loss.item():.3f}\n", end="")

The traceback leads to the encoder and says I have an in-place operation. Could it be related to the Flatten? Or to me using “stale” data when training after the second pass?
Thanks in advance!

Edit: sorry for the very long code block but I wanted to include all possible errors. I tried this training procedure with simpler networks (1D data) and I get no errors so I think it may come from the network and not the training.

When continuing to investigate, I found that replacing the forward of the encoder by something silly like:

def forward(self, x):
    """Debug variant of the encoder forward: feed the heads with random data.

    The convolutional trunk still runs, but only to get the shape of the
    flattened features; the heads receive fresh noise of that shape instead.
    """
    x = self.conv(x)
    noise = torch.randn_like(self.flatten(x))
    # Replaces the regular head inputs:
    #   mu = self.mu(self.flatten(x))
    #   var = self.var(self.flatten(x))
    return self.mu(noise), self.var(noise)

and thereby effectively bypassing the flatten while keeping the dimensions, the error disappears.
I do not really understand how the flatten layer acts in-place however…

Hi Martin!

I haven’t looked at your code in detail, but the likely cause is as follows:

opt_enc.step() performs inplace modifications on the model parameters it is optimizing.

But you then perform a second .backward() on the graph that you retained and modified
inplace, hence the error.

It appears that opt_dec.step() does not optimize any of the encoder parameters. If this
is the case, then the encoder portion of the graph you backward through a second time in
dec_loss.backward() need not be in the decoder graph.

You are already computing dis_x_rec twice – the first time detached from the encoder
graph. Would it be possible to keep the detached version (instead of discarding it) and
use it to compute gen_loss and hence dec_loss? Then dec_loss.backward() won’t
re-traverse the inplace-modified encoder graph.

Of course, you should do some debugging to make sure that this is actually what is happening
(and that there aren’t similar errors that would have shown up after the first error is raised).

You can find a discussion of how to debug and fix inplace-modification errors in this post.

I always recommend understanding what is going on and addressing it directly, but pytorch
does offer a sweep-inplace-modification-errors-under-the-rug context manager.

In reference to your second post:

r_data = torch.randn_like(self.flatten(x)) creates a new tensor that doesn’t depend on x (and doesn’t carry requires_grad = True).

So when you backpropagate through the encoder, you backpropagate through self.mu
and self.var, but not through self.conv. If just parameters in self.conv are the ones that
cause problems when modified inplace, you’ve avoided the problem by not backpropagating
through them.

Best.

K. Frank

Hi @KFrank,
I implemented some of the changes you proposed and got this training:

      # soft labels -- `targets` holds the integer MNIST class labels, so the float dtype is
      # requested explicitly; a plain full_like(targets, 0.9) would truncate 0.1 and 0.9 to 0
      fake = torch.full_like(targets, 0.1, dtype=torch.float).unsqueeze(1)
      real = torch.full_like(targets, 0.9, dtype=torch.float).unsqueeze(1)
      mu, log_var = vaegan.encoder(imgs)
      eps = torch.randn_like(mu)
      std = torch.exp(0.5 * log_var)
      z = mu + eps * std
      # detach z
      z_detached = z.detach()
      x_rec = vaegan.decoder(z)
      x_rec_detached = x_rec.detach() # for use in gan loss
      # ask for a fresh x_rec from the detached z for use in gen loss
      x_rec_detach_gen_only = vaegan.decoder(z_detached)
      z_p = torch.randn_like(z)
      x_p = vaegan.decoder(z_p)
      # run the forwards of the discriminator
      dis_x, dis_x_aux = vaegan.discriminator(imgs) # > for gan and mse
      dis_x_rec, _ = vaegan.discriminator(x_rec_detached) # > for gan
      _, dis_x_rec_aux = vaegan.discriminator(x_rec) # > for mse
      dis_x_rec_gen_only, _ = vaegan.discriminator(x_rec_detach_gen_only) # > for gen
      # detached: otherwise gan_loss would also train the decoder to make x_p look fake
      dis_x_p, _ = vaegan.discriminator(x_p.detach()) # > for gan
      # compute losses
      gan_loss = F.binary_cross_entropy(dis_x, real) + F.binary_cross_entropy(dis_x_rec, fake) + F.binary_cross_entropy(dis_x_p, fake)
      kld = -0.5 * torch.sum(1 + log_var - torch.pow(mu, 2) - torch.exp(log_var))
      mse = F.mse_loss(dis_x_aux, dis_x_rec_aux)
      enc_loss = kld + mse
      gen_loss = F.binary_cross_entropy(dis_x_rec_gen_only, real)
      dec_loss = gamma * mse + gen_loss # reported value; its gradient is assembled piecewise below
      # zero grads and train
      opt_enc.zero_grad()
      opt_dec.zero_grad()
      # encoder <- kld + mse. This pass also leaves d(mse) in the decoder (weight 1) and in
      # the discriminator. retain_graph keeps the discriminator(imgs) graph for gan_loss.
      enc_loss.backward(retain_graph=True)
      # the decoder should see gamma * mse, so rescale what enc_loss just deposited there
      for dec_param in vaegan.decoder.parameters():
          if dec_param.grad is not None:
              dec_param.grad.mul_(gamma)
      # decoder <- + gen_loss (z is detached here, so the encoder is not touched)
      gen_loss.backward()
      # discard the mse / gen_loss gradients that leaked into the discriminator so that it
      # is trained on gan_loss only
      opt_dis.zero_grad()
      gan_loss.backward()
      opt_enc.step()
      opt_dec.step()
      opt_dis.step()

which works fine thank you!
I took some time to create a flow diagram of what was expected below. Dotted tensors are leaves and colors show the gradient flows.


In the end, the only things I do not like are the multiple forward passes on the discriminator and the retain_graph=True remaining, but their use seems justified here.
Tell me if you think there is a cleaner or more principled solution you think of.
Best regards,
Martin

Hi Martin!

I wouldn’t worry about retain_graph = True – it has legitimate use cases and looks
reasonable here. (The problem is that it is sometimes used to hide a substantive underlying
issue, but it’s often appropriate with GANs.)

I’m not commenting on your architecture in particular – just making some general comments
about GAN training:

When training GANs, duplicating the forward pass is often the right – or at least the most
practical – thing to do. But there are a couple of ways to avoid it.

Let’s say you have a generator feeding a discriminator and use something like
BCEWithLogitsLoss as your loss criterion. When the discriminator gets the right answer
(so the discriminator is doing well, but the generator is doing poorly), the loss is low.
You can backpropagate this loss to train the discriminator. But if you insert a custom
layer where the output of the generator is passed to the discriminator that flips the sign
of the gradient during backpropagation, you can then use the same forward / backward
pass also to train the generator.

(Note, although I think this approach is valid, I think that it’s not optimal. This is because
flipping the sign of the gradient is not mathematically the same as flipping the label used
in BCEWithLogitsLoss, so to speak, from “fake” to “real”. I don’t have evidence for this,
but I would expect that backpropagating the flipped label – at the cost of running a second
backward pass – will train the generator more effectively than backpropagating the flipped
sign – which would only require a single backward pass.)

The other thing you can do is perform a forward pass for the generator, make the output
of the generator a leaf variable, perform the forward pass for the discriminator, using the
output of the generator as a leaf-variable input. You can then call .backward() for the
discriminator and the generator separately and have more granular control over the two
“half” backward passes, for example how (and whether) they populate various .grads.

Best.

K. Frank

Thanks @KFrank.
Just to be sure I understand:

This is essentially what I did with detaching tensors and using an optimizer for each part of the whole VAEGAN, or is there another way to do it?
Also:

You could but that would probably be best for another topic on data generation :smile:
Anyway thanks again,
Martin

Hi Martin!

Let me add a little more detail to my previous answer.

Here you run two forward passes through decoder with the same data – z and a
detached version of z. (Your third forward pass through decoder is not duplicative.)

And here you run three forward passes through discriminator with the same data – x_rec
and two detached versions of x_rec. (Your other two discriminator forward passes are
not duplicative.)

If you take more granular control of your backward passes, you can avoid the redundant
forward passes.

Logically, you need only two decoder forward passes – one for the encoded feature
vector, z, and one for the random feature vector, z_p. Similarly, you only need three
forward passes through discriminator – one each for actual image, one for the image
reconstructed from z, and one for the fully fake image reconstructed from z_p.

By using .detach() similarly to what you have done, you can break your backward passes
up into separate pieces for your three submodels. Then by using autograd.backward()
when you want to accumulate gradients into a submodel’s parameters and by alternatively
using autograd.grad() when you want to backpropagate through a submodel without
accumulating gradients, you can properly isolate the desired gradient computations without
running redundant forward passes.

Here is some (untested) pseudocode that illustrates this approach with a slightly simplified
version of what I understand to be your use case:

# enc -- encoder: transforms an image to a feature vector
# dec -- decoder: transforms a feature vector to a reconstructed image
# dsc -- discriminator: predicts whether an image is real or is generated from a feature vector (whether encoded or random)

# train enc and dec (which together form the generator) to  produce a reconstructed image that looks like the
#    original image as measured by a certain feature vector ("aux") internal to the discriminator
# train dec (but not enc) to produce a reconstructed image that looks real to the discriminator
# train dsc to tell a generated image from a real image

# imag:        a real input image

# zenc:        feature vector encoded from imag
# zrnd:        a random, fake feature vector

# irec:        image reconstructed from encoded feature vector
# irnd:        fully fake image decoded from random feature vector

# imag_prd:    discriminator prediction for imag
# irec_prd:    discriminator prediction for irec
# irnd_prd:    discriminator prediction for irnd

# imag_aux:    internal discriminator features produced along with imag_prd
# irec_aux:    internal discriminator features produced along with irec_prd

# labl_img:    label for real image (0.9 for soft label)
# labl_rec:    label for reconstructed image -- from both encoded and random feature vector (0.1 for soft label)

dsc_fn = BCEWithLogitsLoss()              # loss function for predictions of dsc

zenc = enc (imag)                         # feature vector from real image -- only forward pass through enc
zenc_d = zenc.detach().requires_grad_()   # split computation graph into controllable pieces

zrnd = torch.randn (z_dim)                # fake random feature vector

irec = dec (zenc_d)                       # reconstructed image -- first forward pass through dec
irnd = dec (zrnd)                         # random feature vector reconstruction -- second forward pass through dec

irec_d = irec.detach().requires_grad_()   # split computation graph into controllable pieces
irnd_d = irnd.detach().requires_grad_()   # split computation graph into controllable pieces

imag_prd, imag_aux = dsc (imag)           # discriminator prediction and features for real image -- first dsc forward pass
irec_prd, irec_aux = dsc (irec_d)         # prediction and features for reconstructed image from encoded vector -- second forward
irnd_prd, _        = dsc (irnd_d)         # prediction for reconstructed image from random vector -- third forward

dsc_loss = dsc_fn (imag_prd, labl_img) + dsc_fn (irec_prd, labl_rec) + dsc_fn (irnd_prd, labl_rec)
dsc_loss.backward (retain_graph = True)   # first dsc backward pass -- accumulates gradients into dsc, stops at irec_d and irnd_d

dec_loss = dsc_fn (irec_prd, labl_img) + dsc_fn (irnd_prd, labl_img)

# second dsc backward pass -- stops at irec_d and irnd_d, computing their gradients, does not accumulate gradients into dsc
irec_grd, irnd_grd = torch.autograd.grad (dec_loss, (irec_d, irnd_d), retain_graph = True)
# first dec backward pass -- stops at zenc_d, accumulates gradients into dec (and zenc_d, which will be cleared out and not used)
torch.autograd.backward ((irec, irnd), (irec_grd, irnd_grd), retain_graph = True)

gen_loss = F.mse_loss (imag_aux, irec_aux)              # loss for both enc and dec

# third dsc backward pass -- stops at irec_d, but stores its gradient, does not accumulate gradients into dsc
irec_grd_gen = torch.autograd.grad (gen_loss, irec_d)   # (called irec_grd_gen just to indicate that it's not irec_grd)

# second dec backward pass -- accumulates gradients into dec, stops at zenc_d, computing its gen_loss gradient
zenc_d.grad = None                                      # first clear out zenc_d.grad from dec_loss
torch.autograd.backward (irec, irec_grd_gen)

# only enc backward pass -- accumulates gradients into enc
# (the gradient lives on the detached leaf zenc_d; zenc itself is a non-leaf whose .grad is None)
torch.autograd.backward (zenc, zenc_d.grad)

# set various variables to None so that remnant pieces of computation graphs can be freed
dsc_loss = None
dec_loss = None
gen_loss = None
imag_prd = imag_aux = None
irec_prd = irec_aux = None
irnd_prd = _        = None
irec = None
irnd = None
zenc = None

In terms of backward passes, logically there are three loss functions that need to be passed
back through what my code calls dsc (without being mixed together). Two of these need
to be passed back through dec, of which only one needs then be backpropagated through
enc. So I think in terms of both forward and backward passes, the example code is doing
the logical minimum.

(As an aside, you could combine, say, zenc_d and zrnd into a batch – or a larger batch,
if you are already using batches – and perform just a single batched forward pass through
dec. But this wouldn’t be performing any less computation than the two separate forward
passes, although the batched forward pass could well end up making more efficient use of
the floating-point pipelines.)

Best.

K. Frank

Hello @KFrank,
Thank you for the detailed answer. For completeness’ sake, here is the updated code with your proposed modifications. It works fine and I added some of my notes to make sure I understood your code.

# soft labels -- `targets` holds the integer MNIST class labels, so the float dtype is
# requested explicitly; a plain full_like(targets, 0.9) would truncate 0.1 and 0.9 to 0
fake = torch.full_like(targets, 0.1, dtype=torch.float).unsqueeze(1)
real = torch.full_like(targets, 0.9, dtype=torch.float).unsqueeze(1)

# encoder passes
mu, log_var = vaegan.encoder(imgs)
eps = torch.randn_like(mu)
std = torch.exp(0.5 * log_var)
z_enc = mu + eps * std 
# detached leaf copy: the decoder graph stops here and the gradient w.r.t. z lands in its .grad
z_enc_detached = z_enc.detach().requires_grad_()
z_p = torch.randn_like(z_enc_detached)

# decoder passes
img_rec = vaegan.decoder(z_enc_detached)
img_rnd = vaegan.decoder(z_p)
# detach streams
img_rec_detach = img_rec.detach().requires_grad_()
img_rnd_detach = img_rnd.detach().requires_grad_() 
# img_rnd is detached because we do not want the discriminator loss to flow back into the decoder

# discriminator passes
# NOTE(review): bce_loss is not defined in this snippet; presumably F.binary_cross_entropy
# or nn.BCELoss(), given that the discriminator output goes through a sigmoid -- confirm
img_pred, img_aux = vaegan.discriminator(imgs)
img_rec_pred, img_rec_aux = vaegan.discriminator(img_rec_detach)
img_rnd_pred, _ = vaegan.discriminator(img_rnd_detach)

# first backward normal pass through discriminator, accumulating gradients into it
# stops at img_rec_detach, img_rnd_detach
disc_loss = bce_loss(img_pred, real) + bce_loss(img_rec_pred, fake) + bce_loss(img_rnd_pred, fake)
disc_loss.backward(retain_graph=True)

# second backward pass through discriminator - autograd.grad does NOT accumulate gradients
# into the discriminator parameters; it only returns the gradient w.r.t. img_rec_detach
gen_loss = bce_loss(img_rec_pred, real)
img_rec_detach_grads = torch.autograd.grad(gen_loss, img_rec_detach, retain_graph=True)
# then first backward pass through decoder, accumulating gradients
# stops at z_enc_detached
# note: the first argument of this backward is the starting point of the backprop
torch.autograd.backward(img_rec, img_rec_detach_grads, retain_graph=True)

# third backward pass through discriminator - does not accumulate gradients
# stops at img_rec_detach
mse = F.mse_loss(img_aux, img_rec_aux)
img_rec_detach_grads_dec = torch.autograd.grad(gamma * mse, img_rec_detach)
# drop the gen_loss gradient that the first decoder backward left in z_enc_detached.grad:
# the encoder must not be trained on gen_loss, and .grad accumulates across backward calls
z_enc_detached.grad = None 
# second backward pass through decoder, accumulating gamma * d(mse) into its parameters
# stops at z_enc_detached, whose .grad now holds gamma * d(mse)/dz
torch.autograd.backward(img_rec, img_rec_detach_grads_dec) 

# first backward through encoder, accumulating gradients and stopping at imgs
# z_enc_detached.grad carries the decoder's gamma factor; enc_loss uses the unweighted mse,
# so undo the scaling before relinking the gradient to z_enc (requires gamma != 0)
torch.autograd.backward(z_enc, z_enc_detached.grad / gamma, retain_graph=True) 

# we still need to compute kld and backpropagate it (normally)
kld = -0.5 * torch.sum(1 + log_var - torch.pow(mu, 2) - torch.exp(log_var))
kld.backward()

# loss values for reporting; their gradients were assembled piecewise above
dec_loss = gamma * mse + gen_loss
enc_loss = mse + kld

# drop references so that the remaining pieces of the computation graphs can be freed
disc_loss = gen_loss = mse = kld = dec_loss = enc_loss = None
img_rec_detach_grads = img_rec_detach_grads_dec = None
mu = log_var = z_enc = z_enc_detached = eps = std = z_p = None
img_rec = img_rnd = img_rec_detach = img_rnd_detach = None
img_pred = img_aux = img_rec_pred = img_rec_aux = img_rnd_pred = _ = None

I will close the topic and highlight the answer.
Regards,
Martin