StyleGAN2 retain_graph=False RuntimeError

Hi, I am trying to train a simplified version of StyleGAN2 in PyTorch 1.0.1.

Since I do not want to use retain_graph=True, I call backward() without it, and the following RuntimeError is raised: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

My code is below. I would really appreciate your help, thank you!

...
# G & D definition.
G = G_stylegan2(fmap_base=opts.fmap_base,
                    resolution=opts.resolution,
                    mapping_layers=opts.mapping_layers,
                    opts=opts,
                    return_dlatents=opts.return_latents)
D = D_stylegan2(fmap_base=opts.fmap_base,
                    resolution=opts.resolution,
                    structure='resnet')
...
loss_D_list = [0.0]
loss_G_list = [0.0]
softplus = torch.nn.Softplus()
...
# Training.
for epoch in range(0, 100):
    bar = tqdm(loader)
    for i, (real_img,) in enumerate(bar):

        real_img = real_img.to(opts.device)
        latents = torch.randn([real_img.size(0), 512]).to(opts.device)

        # =======================================================================================================
        #   (1) Update D network: D_logistic_r1(default)
        # =======================================================================================================
        # Compute adversarial loss toward discriminator
        real_logit = D(real_img)
        fake_img = G(latents)
        fake_logit = D(fake_img.detach())

        d_loss = softplus(fake_logit).mean()
        d_loss = d_loss + softplus(-real_logit).mean()

        # original
        r1_penalty = D_logistic_r1(real_img.detach(), D)
        d_loss = d_loss + r1_penalty
        # lite
        # d_loss = d_loss.mean()

        loss_D_list.append(d_loss.item())

        # Update discriminator
        optim_D.zero_grad()
        d_loss.backward()
        optim_D.step()

        # =======================================================================================================
        #   (2) Update G network: G_logistic_ns_pathreg(default)
        # =======================================================================================================
        if i % 3 == 0:
            G.zero_grad()
            fake_scores_out = D(fake_img)
            g_loss = softplus(-fake_scores_out).mean()

            loss_G_list.append(g_loss.item())

            # Update generator
            # Note: g_loss.backward(retain_graph=True) works, but that is not the version I want!
            g_loss.backward()
            optim_G.step()

I’m not sure why you are seeing this error. I tried to reproduce it using your code snippet (with some steps removed), and it works fine:

import torch
import torch.nn as nn

D = nn.Linear(1, 1)
G = nn.Linear(1, 1)

optim_D = torch.optim.SGD(D.parameters(), lr=1e-3)
optim_G = torch.optim.SGD(G.parameters(), lr=1e-3)


real_img = torch.randn(1, 1)
latents = torch.randn(1, 1)
# =======================================================================================================
#   (1) Update D network: D_logistic_r1(default)
# =======================================================================================================
# Compute adversarial loss toward discriminator
real_logit = D(real_img)
fake_img = G(latents)
fake_logit = D(fake_img.detach())

d_loss = fake_logit.mean()
d_loss = d_loss + real_logit.mean()

# Update discriminator
optim_D.zero_grad()
d_loss.backward()
optim_D.step()

# =======================================================================================================
#   (2) Update G network: G_logistic_ns_pathreg(default)
# =======================================================================================================
G.zero_grad()
fake_scores_out = D(fake_img)
g_loss = fake_scores_out.mean()

# Update generator
# Note: g_loss.backward(retain_graph=True) works, but that is not the version I want!
g_loss.backward()
optim_G.step()

That being said, the general training routine works fine.
However, the penalty calculation might still produce this error (D_logistic_r1 isn’t defined in your snippet, unfortunately, so I had to remove it).
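Since D_logistic_r1 is not shown in the thread, here is a minimal sketch of an R1 penalty, assuming it follows the usual StyleGAN2 formulation (gamma / 2 times the squared gradient norm of the real logits with respect to the real images). The function name r1_penalty_sketch, the default gamma=10.0, and the toy linear discriminator are illustrative assumptions, not the poster's code. The detail that matters for autograd is create_graph=True, which records the gradient computation so the penalty itself can be backpropagated into D.

```python
import torch
import torch.nn as nn


def r1_penalty_sketch(real_img, D, gamma=10.0):
    """Hypothetical stand-in for D_logistic_r1: gamma / 2 * E[||grad_x D(x)||^2]."""
    # Fresh leaf copy of the real batch so we can differentiate w.r.t. the input
    real_img = real_img.detach().requires_grad_(True)
    real_logit = D(real_img)
    # create_graph=True records the gradient computation itself, so the
    # penalty can later be backpropagated into D's parameters
    grads, = torch.autograd.grad(outputs=real_logit.sum(), inputs=real_img,
                                 create_graph=True)
    # Squared L2 norm per sample, then averaged over the batch
    grad_sq = grads.pow(2).reshape(grads.size(0), -1).sum(dim=1)
    return 0.5 * gamma * grad_sq.mean()


# Toy usage (assumption): a linear "discriminator" on flat inputs
D = nn.Linear(12, 1)
real = torch.randn(4, 12)
penalty = r1_penalty_sketch(real, D)
penalty.backward()
print(penalty.item(), D.weight.grad.shape)
```

For a purely linear D, the input gradient of every sample is just the weight vector, so the penalty reduces to gamma / 2 * ||w||^2, which makes the sketch easy to sanity-check.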


Thank you very much. Could it be some subtle issue in the definition of my G? The error only occurs in g_loss.backward().

My generator looks like this (a re-implementation of StyleGAN2 in PyTorch 1.x):

class G_stylegan2(nn.Module):
    def __init__(self,
                 opts,
                 return_dlatents=True,
                 fmap_base=8 << 10,  # stylegan1 8192 (8 << 10), stylegan2 16384 (16 << 10)
                 num_channels=3,  # Number of output color channels.
                 mapping_fmaps=512,
                 dlatent_size=512,  # Disentangled latent (W) dimensionality.
                 resolution=1024,  # Output resolution.
                 mapping_layers=8,  # Number of mapping layers.
                 randomize_noise=True,
                 fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
                 fmap_min=1,  # Minimum number of feature maps in any layer.
                 fmap_max=512,  # Maximum number of feature maps in any layer.
                 architecture='skip',  # Architecture: 'orig', 'skip'.
                 act='lrelu',  # Activation function: 'linear', 'lrelu'.
                 lrmul=0.01,  # Learning rate multiplier for the mapping layers.
                 gain=1,  # original gain in tensorflow.
                 truncation_psi = 0.7,  # Style strength multiplier for the truncation trick. None = disable.
                 truncation_cutoff = 8,  # Number of layers for which to apply the truncation trick. None = disable.
                 ):
        super().__init__()
        assert architecture in ['orig', 'skip']

        self.return_dlatents = return_dlatents
        self.num_channels = num_channels

        self.g_mapping = G_mapping(mapping_fmaps=mapping_fmaps,
                                   dlatent_size=dlatent_size,
                                   resolution=resolution,
                                   mapping_layers=mapping_layers,
                                   lrmul=lrmul,
                                   gain=gain)

        self.g_synthesis = G_synthesis_stylegan2(resolution=resolution,
                                                 architecture=architecture,
                                                 randomize_noise=randomize_noise,
                                                 fmap_base=fmap_base,
                                                 fmap_min=fmap_min,
                                                 fmap_max=fmap_max,
                                                 fmap_decay=fmap_decay,
                                                 act=act,
                                                 opts=opts)

        self.truncation_cutoff = truncation_cutoff
        self.truncation_psi = truncation_psi

    def forward(self, x):
        dlatents1 = self.g_mapping(x)
        out = self.g_synthesis(dlatents1)

        if self.return_dlatents:
            return out, dlatents1
        else:
            return out

It might be related to the generator.
Do you still see this issue if you remove the penalty calculation?
Also, would it be possible to post the complete generator code, or is it still unpublished work in progress?
In that case, could you try to come up with a small code snippet that reproduces this issue?


Thanks for your kindness! :smiley: I’d be glad to post my code (it’s a bit large). I removed the penalty calculation and the same error still occurs.

The generator is defined in stylegan2.py, which has two main classes: G_mapping and G_synthesis_stylegan2. Thank you very much, and if you are interested in completing this, that would be great!

Hi, I found that in the BiasAdd module, if I do not wrap self.bias with nn.Parameter, I do not need retain_graph at all, but if I wrap it with nn.Parameter, I am forced to add retain_graph=True.

Do you know why? Thanks a lot!

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class BiasAdd(nn.Module):

    def __init__(self,
                 channels,
                 opts,
                 act='linear', alpha=None, gain=None, lrmul=1):
        """
            BiasAdd
        """
        super(BiasAdd, self).__init__()

        self.opts = opts
        # fixme:
        # self.bias = nn.Parameter(torch.zeros(channels, 1, 1) * lrmul).to(opts.device) # error!
        self.bias = (torch.zeros(channels, 1, 1) * lrmul).to(opts.device) # ok

        self.act = act
        self.alpha = alpha if alpha is not None else 0.2
        self.gain = gain if gain is not None else 1.0

    def forward(self, x):
        # Pass Add bias.
        # x += self.bias also works (when self.bias is wrapped with nn.Parameter)
        x = x + self.bias

        # Evaluate activation function.
        if self.act == "linear":
            pass
        elif self.act == 'lrelu':
            x = F.leaky_relu(x, self.alpha, inplace=True)
            x = x * np.sqrt(2)  # original repo def_gain=np.sqrt(2).

        # Scale by gain.
        if self.gain != 1:
            x = x * self.gain

        return x

It seems one BiasAdd module is reused somewhere in the code, but I cannot spot the line.
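For reference, here is a minimal, self-contained sketch (not taken from the thread’s code) of one way this error can arise: a tensor derived from a parameter is computed once and then reused in every iteration. The first backward() frees the saved buffers of that shared part of the graph, so the second backward() fails. The names p, w, and train, and the use of exp() (chosen because its backward saves its output), are illustrative assumptions.

```python
import torch
import torch.nn as nn


def train(recompute_each_step):
    torch.manual_seed(0)
    p = nn.Parameter(torch.ones(3))
    w = p.exp()  # derived tensor built ONCE; its graph (p -> exp) is shared
    for step in range(2):
        if recompute_each_step:
            w = p.exp()  # fresh graph every iteration
        x = torch.randn(3)
        loss = (w * x).sum()
        loss.backward()  # default retain_graph=False frees saved buffers
    return p.grad


try:
    train(recompute_each_step=False)
except RuntimeError as e:
    print("shared graph:", e)

print("recomputed:", train(recompute_each_step=True))
```

Recomputing the derived tensor inside the forward pass (or keeping the leaf nn.Parameter itself as the module attribute) gives each iteration its own graph, which is why retain_graph=True is no longer needed.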

This should be unrelated to your current issue, but I would recommend setting up the bias tensor first and wrapping it in nn.Parameter as the final step, so that it is a leaf variable.
Currently you are creating a non-leaf variable by calling .to() on the parameter.
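The difference can be checked directly. The sketch below uses a dtype change (torch.float64) instead of a device change so that it runs on CPU; that substitution is an assumption for illustration, but calling .to() with a different device likewise returns a new, non-leaf tensor. (If the parameter is already on the target device and dtype, .to() simply returns the original parameter.) The class names BiasWrong and BiasRight are hypothetical.

```python
import torch
import torch.nn as nn


class BiasWrong(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # .to() returns a NEW tensor with a grad_fn: non-leaf, not an
        # nn.Parameter, so the module does not register it as a parameter
        self.bias = nn.Parameter(torch.zeros(channels, 1, 1)).to(torch.float64)


class BiasRight(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # Build the tensor first, wrap it in nn.Parameter as the last step
        self.bias = nn.Parameter(torch.zeros(channels, 1, 1, dtype=torch.float64))


for cls in (BiasWrong, BiasRight):
    m = cls(4)
    print(cls.__name__, "is_leaf:", m.bias.is_leaf,
          "registered:", [n for n, _ in m.named_parameters()])
```

Because BiasWrong.bias never appears in parameters(), an optimizer built from G.parameters() would also never update the underlying parameter.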


Really appreciate your help!

Did the change in creating the parameter solve the issue, since you’ve marked the post as the solution?

I now define BiasAdd as below, and it works well without retain_graph=True.

class BiasAdd(nn.Module):

    def __init__(self,
                 channels,
                 opts,
                 act='linear', alpha=None, gain=None, lrmul=1):
        """
            BiasAdd
        """
        super(BiasAdd, self).__init__()

        self.opts = opts
        self.bias = torch.nn.Parameter(torch.zeros(channels, 1, 1) * lrmul)

    def forward(self, x):
        x += self.bias
        ...
        return x
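Since the final version no longer calls .to(opts.device) inside __init__, the bias still has to reach the GPU somehow. A common pattern, sketched below under the assumption that the whole generator is moved at once (for example G.to(opts.device)), is to move the module rather than the tensor: nn.Module.to() converts registered parameters while keeping them as leaf nn.Parameter objects. BiasAddSketch is a hypothetical, trimmed-down stand-in for the thread’s BiasAdd, and a dtype change stands in for a device change so the sketch runs on CPU.

```python
import torch
import torch.nn as nn


class BiasAddSketch(nn.Module):
    def __init__(self, channels, lrmul=1.0):
        super().__init__()
        # Create the tensor first, wrap it in nn.Parameter last -> leaf
        self.bias = nn.Parameter(torch.zeros(channels, 1, 1) * lrmul)

    def forward(self, x):
        # Out-of-place add: leaves the caller's x untouched
        return x + self.bias


m = BiasAddSketch(4)
m.to(torch.float64)  # in the real code this would be e.g. G.to(opts.device)
print(type(m.bias).__name__, m.bias.dtype, m.bias.is_leaf)

out = m(torch.randn(2, 4, 3, 3, dtype=torch.float64))
out.sum().backward()
print(m.bias.grad.shape)
```

The sketch uses an out-of-place x + self.bias. The in-place x += self.bias from the post often works too, but it modifies the caller’s tensor and can trigger a different autograd error if that tensor was saved for backward by an earlier operation.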