Error - PyTorch BiGAN

Hi, I implemented the BiGAN code and I am facing this error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 1024, 1, 1]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Here is the code:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Discriminator(nn.Module):
    def __init__(self, z_dim=32, wasserstein=False):
        super(Discriminator, self).__init__()
        self.wass = wasserstein

        # Inference over x
        self.conv1x = nn.Conv2d(3, 32, 5, stride=1, bias=False)
        self.conv2x = nn.Conv2d(32, 64, 4, stride=2, bias=False)
        self.bn2x = nn.BatchNorm2d(64)
        self.conv3x = nn.Conv2d(64, 128, 4, stride=1, bias=False)
        self.bn3x = nn.BatchNorm2d(128)
        self.conv4x = nn.Conv2d(128, 256, 4, stride=2, bias=False)
        self.bn4x = nn.BatchNorm2d(256)
        self.conv5x = nn.Conv2d(256, 512, 4, stride=1, bias=False)
        self.bn5x = nn.BatchNorm2d(512)

        # Inference over z
        self.conv1z = nn.Conv2d(z_dim, 512, 1, stride=1, bias=False)
        self.conv2z = nn.Conv2d(512, 512, 1, stride=1, bias=False)

        # Joint inference
        self.conv1xz = nn.Conv2d(1024, 1024, 1, stride=1, bias=False)
        self.conv2xz = nn.Conv2d(1024, 1024, 1, stride=1, bias=False)
        self.conv3xz = nn.Conv2d(1024, 1, 1, stride=1, bias=False)

    def inf_x(self, x):
        x = F.dropout2d(F.leaky_relu(self.conv1x(x), negative_slope=0.1), 0.2, inplace=False)
        x = F.dropout2d(F.leaky_relu(self.bn2x(self.conv2x(x)), negative_slope=0.1), 0.2, inplace=False)
        x = F.dropout2d(F.leaky_relu(self.bn3x(self.conv3x(x)), negative_slope=0.1), 0.2, inplace=False)
        x = F.dropout2d(F.leaky_relu(self.bn4x(self.conv4x(x)), negative_slope=0.1), 0.2, inplace=False)
        x = F.dropout2d(F.leaky_relu(self.bn5x(self.conv5x(x)), negative_slope=0.1), 0.2, inplace=False)
        return x

    def inf_z(self, z):
        z = F.dropout2d(F.leaky_relu(self.conv1z(z), negative_slope=0.1), 0.2, inplace=False)
        z = F.dropout2d(F.leaky_relu(self.conv2z(z), negative_slope=0.1), 0.2, inplace=False)
        return z

    def inf_xz(self, xz):
        xz = F.dropout(F.leaky_relu(self.conv1xz(xz), negative_slope=0.1), 0.2, inplace=False)
        xz = F.dropout(F.leaky_relu(self.conv2xz(xz), negative_slope=0.1), 0.2, inplace=False)
        return self.conv3xz(xz)

    def forward(self, x, z):
        x = self.inf_x(x)
        z = self.inf_z(z)
        xz = torch.cat((x, z), dim=1)
        out = self.inf_xz(xz)
        if self.wass:
            return out
        else:
            return torch.sigmoid(out)

class Generator(nn.Module):
    def __init__(self, z_dim=32):
        super(Generator, self).__init__()
        self.z_dim = z_dim

        self.output_bias = nn.Parameter(torch.zeros(3, 32, 32), requires_grad=True)
        self.deconv1 = nn.ConvTranspose2d(z_dim, 256, 4, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(256)
        self.deconv2 = nn.ConvTranspose2d(256, 128, 4, stride=2, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.deconv3 = nn.ConvTranspose2d(128, 64, 4, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(64)
        self.deconv4 = nn.ConvTranspose2d(64, 32, 4, stride=2, bias=False)
        self.bn4 = nn.BatchNorm2d(32)
        self.deconv5 = nn.ConvTranspose2d(32, 32, 5, stride=1, bias=False)
        self.bn5 = nn.BatchNorm2d(32)
        self.deconv6 = nn.Conv2d(32, 3, 1, stride=1, bias=True)

    def forward(self, z):
        z = F.leaky_relu(self.bn1(self.deconv1(z)), negative_slope=0.1, inplace=False)
        z = F.leaky_relu(self.bn2(self.deconv2(z)), negative_slope=0.1, inplace=False)
        z = F.leaky_relu(self.bn3(self.deconv3(z)), negative_slope=0.1, inplace=False)
        z = F.leaky_relu(self.bn4(self.deconv4(z)), negative_slope=0.1, inplace=False)
        z = F.leaky_relu(self.bn5(self.deconv5(z)), negative_slope=0.1, inplace=False)
        return torch.sigmoid(self.deconv6(z) + self.output_bias)


class Encoder(nn.Module):
    def __init__(self, z_dim=32):
        super(Encoder, self).__init__()
        self.z_dim = z_dim
        self.conv1 = nn.Conv2d(3, 32, 5, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, 4, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, 4, stride=2, bias=False)
        self.bn4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(256, 512, 4, stride=1, bias=False)
        self.bn5 = nn.BatchNorm2d(512)
        self.conv6 = nn.Conv2d(512, 512, 1, stride=1, bias=False)
        self.bn6 = nn.BatchNorm2d(512)
        self.bn7 = nn.Conv2d(512, z_dim*2, 1, stride=1, bias=True)

    def reparameterize(self, z):
        z = z.view(z.size(0), -1)
        mu, log_sigma = z[:, :self.z_dim], z[:, self.z_dim:]
        std = torch.exp(log_sigma)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        x = F.leaky_relu(self.bn1(self.conv1(x)), negative_slope=0.1, inplace=False)
        x = F.leaky_relu(self.bn2(self.conv2(x)), negative_slope=0.1, inplace=False)
        x = F.leaky_relu(self.bn3(self.conv3(x)), negative_slope=0.1, inplace=False)
        x = F.leaky_relu(self.bn4(self.conv4(x)), negative_slope=0.1, inplace=False)
        x = F.leaky_relu(self.bn5(self.conv5(x)), negative_slope=0.1, inplace=False)
        x = F.leaky_relu(self.bn6(self.conv6(x)), negative_slope=0.1, inplace=False)
        z = self.reparameterize(self.conv6(x))
        return z.view(x.size(0), self.z_dim, 1, 1)

import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision.utils as vutils

import numpy as np
from barbar import Bar

class TrainerBiGAN:
    def __init__(self, args, data, device):
        self.args = args
        self.train_loader = data
        self.device = device

    def train(self):
        """Training the BiGAN"""
        self.G = Generator(self.args.latent_dim).to(self.device)
        self.E = Encoder(self.args.latent_dim).to(self.device)
        self.D = Discriminator(self.args.latent_dim, self.args.wasserstein).to(self.device)

        self.G.apply(weights_init_normal)
        self.E.apply(weights_init_normal)
        self.D.apply(weights_init_normal)

        if self.args.wasserstein:
            optimizer_ge = optim.RMSprop(list(self.G.parameters()) +
                                         list(self.E.parameters()), lr=self.args.lr_rmsprop)
            optimizer_d = optim.RMSprop(self.D.parameters(), lr=self.args.lr_rmsprop)
        else:
            optimizer_ge = optim.Adam(list(self.G.parameters()) +
                                      list(self.E.parameters()), lr=self.args.lr_adam)
            optimizer_d = optim.Adam(self.D.parameters(), lr=self.args.lr_adam)

        fixed_z = Variable(torch.randn((16, self.args.latent_dim, 1, 1)),
                           requires_grad=False).to(self.device)
        criterion = nn.BCELoss()
        for epoch in range(self.args.num_epochs+1):
            ge_losses = 0
            d_losses = 0
            for x, _ in Bar(self.train_loader):
                # Defining labels
                y_true = Variable(torch.ones((x.size(0), 1)).to(self.device))
                y_fake = Variable(torch.zeros((x.size(0), 1)).to(self.device))

                # Noise for improving training.
                noise1 = Variable(torch.Tensor(x.size()).normal_(0,
                                  0.1 * (self.args.num_epochs - epoch) / self.args.num_epochs),
                                  requires_grad=False).to(self.device)
                noise2 = Variable(torch.Tensor(x.size()).normal_(0,
                                  0.1 * (self.args.num_epochs - epoch) / self.args.num_epochs),
                                  requires_grad=False).to(self.device)

                # Cleaning gradients.
                optimizer_d.zero_grad()
                optimizer_ge.zero_grad()

                # Generator:
                z_fake = Variable(torch.randn((x.size(0), self.args.latent_dim, 1, 1)).to(self.device),
                                  requires_grad=False)
                x_fake = self.G(z_fake)

                # Encoder:
                x_true = x.float().to(self.device)
                z_true = self.E(x_true)

                # Discriminator
                out_true = self.D(x_true + noise1, z_true)
                out_fake = self.D(x_fake + noise2, z_fake)

                # Losses
                if self.args.wasserstein:
                    loss_d = -torch.mean(out_true) + torch.mean(out_fake)
                    loss_ge = -torch.mean(out_fake) + torch.mean(out_true)
                else:
                    # 27/01/21
                    y_true = y_true.unsqueeze(1)
                    y_true = y_true.unsqueeze(1)
                    y_fake = y_fake.unsqueeze(1)
                    y_fake = y_fake.unsqueeze(1)
                    loss_d = criterion(out_true, y_true) + criterion(out_fake, y_fake)
                    loss_ge = (criterion(out_fake, y_true) + criterion(out_true, y_fake)).clone()

                # Computing gradients and backpropagate.
                loss_d.backward(retain_graph=True)
                optimizer_d.step()
                # with torch.autograd.set_detect_anomaly(True):
                torch.autograd.set_detect_anomaly(True)
                loss_ge.backward(retain_graph=True)
                optimizer_ge.step()

                if self.args.wasserstein:
                    for p in self.D.parameters():
                        p.data.clamp_(-self.args.clamp, self.args.clamp)
                ge_losses += loss_ge.item()
                d_losses += loss_d.item()

            if epoch % 50 == 0:
                vutils.save_image(self.G(fixed_z).data, './images/{}_fake.png'.format(epoch))

            print("Training... Epoch: {}, Discriminator Loss: {:.3f}, Generator Loss: {:.3f}".format(
                epoch, d_losses/len(self.train_loader), ge_losses/len(self.train_loader)
            ))

import torch

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1 and classname != 'Conv':
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif classname.find("Linear") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.01)
        if m.bias is not None:
            m.bias.data.fill_(0)
import torch
import numpy as np
from torchvision import datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

def get_cifar10(args, data_dir='./data/cifar/'):
    """Returning the CIFAR-10 dataloader."""
    transform = transforms.Compose([transforms.Resize(32),  # 3x32x32 images.
                                    transforms.ToTensor()])
    data = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
    dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=True)
    return dataloader
import numpy as np
import argparse
import torch

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_epochs", type=int, default=200,
                        help="number of epochs")
    parser.add_argument('--lr_adam', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('--lr_rmsprop', type=float, default=1e-4,
                        help='learning rate RMSprop if WGAN is True.')
    parser.add_argument("--batch_size", type=int, default=128,
                        help="Batch size")
    parser.add_argument('--latent_dim', type=int, default=256,
                        help='Dimension of the latent variable z')
    parser.add_argument('--wasserstein', type=bool, default=False,
                        help='If WGAN.')
    parser.add_argument('--clamp', type=float, default=1e-2,
                        help='Clipping gradients for WGAN.')
    # parsing arguments.
    args = parser.parse_args()

    # check if cuda is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data = get_cifar10(args)

    bigan = TrainerBiGAN(args, data, device)
    bigan.train()
```

Could someone help me, please?

Could you please check this post? It describes common errors in GAN training where stale activations are used, which might raise this error.

PS: you can post code snippets by wrapping them into three backticks ```, which makes debugging easier. :wink:

Hi, the problem comes from the computation of gradients and backpropagation:
```python
loss_d.backward(retain_graph=True)
optimizer_d.step()

loss_ge.backward(retain_graph=True)
optimizer_ge.step()
```

When I remove these two lines:
```python
loss_ge.backward(retain_graph=True)
optimizer_ge.step()
```
the code works, but that is surely incorrect. So I understand that the problem comes from here. Do you have any suggestions, please?
Thank you,

The link in my previous post explains this behavior in more detail.
The short version is:

  • loss_d and loss_ge were calculated in some code before your posted code snippet
  • optimizer_d.step() changes the parameters of the discriminator
  • loss_ge.backward() calculates the gradients by backpropagating through the discriminator into the generator
  • since the discriminator was already updated (changed parameters), the backward pass fails (see the minimal reproduction below)
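
Here is a tiny, self-contained reproduction of that failure mode (it uses two stand-in `nn.Linear` layers and plain SGD instead of your models; only the ordering of the calls matters):

```python
import torch

# Stand-ins for the discriminator and the generator (toy modules, not your code).
d = torch.nn.Linear(4, 1)
g = torch.nn.Linear(4, 4)
opt_d = torch.optim.SGD(d.parameters(), lr=0.1)
opt_g = torch.optim.SGD(g.parameters(), lr=0.1)

out = d(g(torch.randn(2, 4)))   # one forward pass shared by both losses
loss_d = out.mean()
loss_g = -out.mean()

loss_d.backward(retain_graph=True)
opt_d.step()        # updates d.weight in-place, so its "version" changes
loss_g.backward()   # needs the old d.weight -> RuntimeError: ... modified by an inplace operation
```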

To avoid these issues, you could e.g. update the discriminator first, then perform the forward pass and loss calculation for the generator update, and calculate the gradients for the generator as the last step before calling optimizer_ge.step().
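
In your loop that could look roughly like this (a sketch only, reusing the names from your train() loop and showing the non-Wasserstein branch; detaching the generator/encoder outputs for the discriminator step means neither backward needs retain_graph=True anymore):

```python
# 1) Discriminator update: cut the graph to G/E via .detach(), then step.
optimizer_d.zero_grad()
out_true = self.D(x_true + noise1, z_true.detach())
out_fake = self.D(x_fake.detach() + noise2, z_fake)
loss_d = criterion(out_true, y_true) + criterion(out_fake, y_fake)
loss_d.backward()
optimizer_d.step()

# 2) Generator/Encoder update: fresh forward pass through the updated D,
#    so the graph matches the discriminator's current parameters.
optimizer_ge.zero_grad()
out_true = self.D(x_true + noise1, z_true)
out_fake = self.D(x_fake + noise2, z_fake)
loss_ge = criterion(out_fake, y_true) + criterion(out_true, y_fake)
loss_ge.backward()
optimizer_ge.step()
```

The essential point is that each backward() call runs on a graph whose discriminator parameters have not been modified by an optimizer step in between, so no retain_graph=True (and no .clone()) is needed.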

Thank you for your answer, I made this change and it works. Could you please tell me if it's correct?

```python
loss_d.backward(retain_graph=True)
loss_ge.backward()
optimizer_d.step()
optimizer_ge.step()
```

Thank you,

I don’t think this is correct, since the discriminator would now get the gradients for the generator as well.
Have a look at the DCGAN example to see how both models are updated.
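
For reference, the update pattern in that example looks roughly like this (paraphrased rather than copied; netD, netG, real, noise, the label tensors, and the optimizer names are placeholders here):

```python
# (1) Discriminator step: real batch plus detached fake batch.
netD.zero_grad()
loss_d_real = criterion(netD(real), real_labels)
fake = netG(noise)
loss_d_fake = criterion(netD(fake.detach()), fake_labels)   # detach: no grads reach netG
(loss_d_real + loss_d_fake).backward()
optimizer_d.step()

# (2) Generator step: a fresh forward pass through the updated discriminator.
netG.zero_grad()
loss_g = criterion(netD(fake), real_labels)   # no detach: grads flow back into netG
loss_g.backward()
optimizer_g.step()
```

Each optimizer only steps on gradients produced by its own backward() call, and the zero_grad() calls ensure the discriminator's step never uses gradients accumulated during the generator update.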