How to correctly apply spectral normalization for WGAN?

I am using aladdinpersson’s WGAN-GP code: Machine-Learning-Collection/ML/Pytorch/GANs/4. WGAN-GP/train.py (aladdinpersson/Machine-Learning-Collection on GitHub, master branch).

I removed the GP part and instead applied spectral normalization to both the critic and the generator, as follows:

```python
critic_fake = critic(fake).reshape(-1)
gp = gradient_penalty(critic, real, fake, device=device)
loss_critic = (
    -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
)
critic.zero_grad()
```

becomes:

```python
critic_fake = critic(fake).reshape(-1)
loss_critic = (
    -(torch.mean(critic_real) - torch.mean(critic_fake))
)
critic.zero_grad()
```
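For reference, this edit removes the penalty term from the standard WGAN-GP critic loss, so the Lipschitz constraint now has to come entirely from spectral normalization. Writing $D$ for the critic, $G$ for the generator, $\lambda$ for LAMBDA_GP and $\hat{x}$ for a random interpolate between a real and a fake sample:

$$
\begin{aligned}
L_{\text{GP}} &= -\big(\mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim p_z}[D(G(z))]\big) + \lambda\,\mathbb{E}_{\hat{x}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big] \\
L_{\text{SN}} &= -\big(\mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{z \sim p_z}[D(G(z))]\big)
\end{aligned}
$$

Spectral normalization replaces each weight $W_l$ with $\bar{W}_l = W_l / \sigma(W_l)$, where $\sigma(\cdot)$ is the largest singular value. For a stack of dense layers with 1-Lipschitz activations such as ReLU or LeakyReLU this gives

$$
\lVert D \rVert_{\text{Lip}} \le \prod_l \sigma(\bar{W}_l) = 1 .
$$

For convolutions, PyTorch normalizes the kernel reshaped into a matrix, which is not exactly the operator norm of the convolution, so the bound only holds approximately there; normalization layers between the convs are not covered by it at all.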

and:

```python
# initialize gen and disc, note: discriminator should be called critic,
# according to WGAN paper (since it no longer outputs between [0, 1])
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)
```

becomes:

```python
# initialize gen and disc, note: discriminator should be called critic,
# according to WGAN paper (since it no longer outputs between [0, 1])
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)
apply_spectral_norm(gen)
apply_spectral_norm(critic)
```

Where apply_spectral_norm() is defined as:

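```python
import torch.nn as nn


def flatten_network(net):
    # Return the module itself followed by all of its submodules (pre-order),
    # i.e. the same modules as list(net.modules())
    flattened = [flatten_network(children) for children in net.children()]
    res = [net]
    for c in flattened:
        res += c
    return res


def apply_spectral_norm(net):
    # Apply spectral normalization (legacy hook-based API) to conv and linear layers
    for p in flatten_network(net):
        if isinstance(p, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            nn.utils.spectral_norm(p)
```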
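A quick way to confirm that this first version really wrapped the layers: the legacy nn.utils.spectral_norm moves each weight into a `weight_orig` parameter, so its presence marks a wrapped layer. The sketch repeats the two helpers from the post so it runs on its own; `spectral_norm_report` and the toy two-conv network are hypothetical additions for illustration.

```python
import torch.nn as nn


def flatten_network(net):
    # Same helper as in the post: the module itself, then all submodules (pre-order)
    flattened = [flatten_network(children) for children in net.children()]
    res = [net]
    for c in flattened:
        res += c
    return res


def apply_spectral_norm(net):
    # Same helper as in the post (legacy hook-based API)
    for p in flatten_network(net):
        if isinstance(p, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            nn.utils.spectral_norm(p)


def spectral_norm_report(net):
    # Hypothetical helper: the legacy API moves `weight` to a `weight_orig`
    # parameter, so its presence marks a layer that was actually wrapped.
    layer_types = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
    return {name: hasattr(m, "weight_orig")
            for name, m in net.named_modules() if isinstance(m, layer_types)}


# Toy stand-in for the critic
critic = nn.Sequential(nn.Conv2d(3, 8, 4, 2, 1), nn.LeakyReLU(0.2), nn.Conv2d(8, 1, 4, 1, 0))
# flatten_network visits the same modules as the built-in net.modules()
print(flatten_network(critic) == list(critic.modules()))  # True
apply_spectral_norm(critic)
print(spectral_norm_report(critic))  # {'0': True, '2': True}
```

If the report shows True for every conv/linear layer, the original helper was applying spectral norm where intended; in that case `flatten_network` can simply be replaced by `net.modules()`.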

I expected the GAN to converge at a similar speed, but this is what I see:

[Plot not shown: blue is WGAN-GP, red is WGAN-SN]

Any tips on what I am missing?

Bumping this because I still haven’t found a good configuration.

Maybe I have found two improvements that work.

#1: I think my way of applying spectral norm may have been incorrect.

Now I apply it to all conv, linear, and normalization layers:

```python
def apply_spectral_norm(net: torch.nn.Module, power=1):
    # Wrap conv, linear, and normalization layers with the parametrization-based spectral norm
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear,
                          nn.BatchNorm2d, nn.BatchNorm1d,
                          nn.InstanceNorm2d, nn.InstanceNorm1d)):
            nn.utils.parametrizations.spectral_norm(m, n_power_iterations=power)
```

This also uses the updated and non-deprecated way of adding spectral norm.
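On the normalization layers: the parametrization-based `spectral_norm` only runs power iteration on weights with at least two dimensions. BatchNorm/InstanceNorm affine weights are 1-D, and for those it simply rescales the vector to unit L2 norm. That does not bound the layer's Lipschitz constant, because the layer also divides by a data-dependent standard deviation. An `InstanceNorm2d` created with the default `affine=False` has no weight at all, so the call fails. A small check, assuming a PyTorch version that ships `torch.nn.utils.parametrizations`:

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm

# BatchNorm's affine weight (gamma) is 1-D, so there is no matrix for power
# iteration; spectral_norm just rescales the vector to unit L2 norm.
bn = spectral_norm(nn.BatchNorm2d(16))
print(bn.weight)  # gamma starts as all ones, so every entry becomes 1/sqrt(16) = 0.25
print(torch.linalg.vector_norm(bn.weight))  # 1.0

# InstanceNorm2d defaults to affine=False, i.e. there is no weight to normalize.
try:
    spectral_norm(nn.InstanceNorm2d(16))
    instance_norm_ok = True
except ValueError as err:
    instance_norm_ok = False
    print("InstanceNorm2d:", err)
```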

Is there a better way to do this, e.g. by checking whether each module has a “weight”?
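One way to do that check, sketched under the assumption that only modules owning a weight with two or more dimensions (conv kernels, linear matrices) should be wrapped: look up `weight` with `getattr` and test its `ndim`. This skips modules without a weight and the 1-D weights of norm layers; note that it would also wrap other matrix-weighted modules such as `nn.Embedding`. The helper name is reused from the post, but this body is a hypothetical variant.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm


def apply_spectral_norm(net: nn.Module, n_power_iterations: int = 1) -> nn.Module:
    """Wrap every module that owns a weight with ndim > 1 in spectral norm.

    Modules without a `weight`, or with a 1-D weight (BatchNorm/InstanceNorm/
    LayerNorm gamma), are skipped. Already-wrapped modules expose `weight` as a
    computed tensor rather than an nn.Parameter, so calling this twice is harmless.
    """
    # Snapshot the module list: registering a parametrization adds submodules.
    for module in list(net.modules()):
        weight = getattr(module, "weight", None)
        if isinstance(weight, nn.Parameter) and weight.ndim > 1:
            spectral_norm(module, n_power_iterations=n_power_iterations)
    return net


# Toy critic for 32x32 inputs, only to show which layers get wrapped
critic = nn.Sequential(
    nn.Conv2d(3, 8, 4, 2, 1),           # 4-D kernel -> wrapped
    nn.InstanceNorm2d(8, affine=True),  # 1-D gamma  -> skipped
    nn.LeakyReLU(0.2),                  # no weight  -> skipped
    nn.Flatten(),
    nn.Linear(8 * 16 * 16, 1),          # 2-D matrix -> wrapped
)
apply_spectral_norm(critic)
print([hasattr(m, "parametrizations") for m in critic])  # [True, False, False, False, True]
```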

#2: Now I only apply it to the discriminator:

```python
# initialize gen and disc, note: discriminator should be called critic,
# according to WGAN paper (since it no longer outputs between [0, 1])
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)
# apply_spectral_norm(gen)
apply_spectral_norm(critic)
```

This seems to work better than before.

But when I try this on the minimal example above, it still doesn’t come close to WGAN-GP:


If I’m not wrong, spectral normalization should be applied at every iteration (with PyTorch’s spectral_norm this happens automatically on each forward pass once it is registered), and typically only to the critic. Also, increase n_power_iterations if you haven’t, because more iterations give a better estimate of the spectral norm.
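To illustrate what registering spectral norm once does in PyTorch, here is a small sketch; the toy `nn.Linear` layer, the dummy loss and the learning rate are arbitrary choices. Each time the normalized weight is computed in training mode, the stored u/v vectors get `n_power_iterations` more power-iteration steps, and the weight is divided by the resulting estimate of its largest singular value. Comparing against the exact value from `torch.linalg.matrix_norm(..., ord=2)` shows how good the estimate is: the normalized value can only be at or above 1.0, and the gap is the estimation error that a larger `n_power_iterations` shrinks.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm

torch.manual_seed(0)


def sigma_max(w: torch.Tensor) -> float:
    # Exact largest singular value (spectral norm) of a 2-D weight matrix
    return torch.linalg.matrix_norm(w, ord=2).item()


# Spectral norm is registered ONCE; afterwards every computation of layer.weight
# in train mode runs n_power_iterations updates of the stored u/v vectors and
# returns the raw weight divided by the estimated spectral norm.
layer = spectral_norm(nn.Linear(64, 32), n_power_iterations=1)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)
x = torch.randn(16, 64)

history = []
for step in range(5):
    loss = layer(x).pow(2).mean()  # dummy loss, only used to change the weights
    opt.zero_grad()
    loss.backward()
    opt.step()  # updates the raw (unnormalized) parameter only
    with torch.no_grad():
        raw = sigma_max(layer.parametrizations.weight.original)
        # Reading layer.weight in train mode also runs a power-iteration update
        eff = sigma_max(layer.weight)
    history.append((raw, eff))
    # eff should sit at or slightly above 1.0 at every step
    print(f"step {step}: raw sigma = {raw:.3f}, normalized sigma = {eff:.3f}")
```

Because u and v are stored as buffers and carried over between steps, even the default of one iteration per forward pass keeps refining the estimate as long as the weights change slowly; raising `n_power_iterations` trades extra compute for a tighter estimate.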

Typically, the learnable layers in the critic that have weights are the convolution and dense (linear) layers, so you can apply it to those manually instead of writing a function to check. As far as I know, you don’t need any other type of layer in the critic network, because things like batch normalization don’t help much there.
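A sketch of applying it by hand, assuming a DCGAN-style critic for 64x64 images. `SNCritic`, the layer sizes, and the stride-1 final conv are illustrative choices, not the tutorial's code: each conv is wrapped in `spectral_norm` as it is created, and the normalization layers are simply dropped. Since no norm layer follows each conv, the convs keep their bias.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm


class SNCritic(nn.Module):
    """Hypothetical DCGAN-style critic for 64x64 images: spectral norm wraps each
    conv layer by hand, and there are no BatchNorm/InstanceNorm layers."""

    def __init__(self, channels_img: int = 3, features: int = 64):
        super().__init__()

        def block(in_ch: int, out_ch: int) -> nn.Sequential:
            # 4x4 stride-2 conv halves the spatial size; LeakyReLU is 1-Lipschitz
            return nn.Sequential(
                spectral_norm(nn.Conv2d(in_ch, out_ch, 4, 2, 1)),
                nn.LeakyReLU(0.2),
            )

        self.net = nn.Sequential(
            block(channels_img, features),      # 64 -> 32
            block(features, features * 2),      # 32 -> 16
            block(features * 2, features * 4),  # 16 -> 8
            block(features * 4, features * 8),  # 8 -> 4
            # 4 -> 1, unbounded critic score (no sigmoid)
            spectral_norm(nn.Conv2d(features * 8, 1, 4, 1, 0)),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


critic = SNCritic()
scores = critic(torch.randn(2, 3, 64, 64))
print(scores.shape)  # torch.Size([2, 1, 1, 1])
```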

Did you try running it for longer? My guess is that since you are dividing by the spectral norm, the weights move closer to zero, so smaller gradients are backpropagated and the earlier layers and the generator learn more slowly. Increase the learning rate for the spectral norm run, or just wait a bit longer.
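Before raising the learning rate, the guess above can be checked by logging how much gradient actually flows. A minimal sketch with two hypothetical helpers: `grad_norm` reports the global norm of a model's parameter gradients after `backward()`, and `input_grad_norm` reports the average norm of the critic's gradient with respect to its input, which is the signal the generator receives. The toy critic at the bottom only demonstrates the calls; logging both numbers for the GP and SN runs would show whether the SN run really gets smaller gradients.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm


def grad_norm(model: nn.Module) -> float:
    """Global L2 norm over all parameter gradients currently stored in .grad."""
    norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
    if not norms:
        return 0.0
    return torch.linalg.vector_norm(torch.stack(norms)).item()


def input_grad_norm(critic: nn.Module, fake: torch.Tensor) -> float:
    """Mean per-sample L2 norm of d critic(fake) / d fake, i.e. what reaches G."""
    fake = fake.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(critic(fake).sum(), fake)
    return grad.flatten(1).norm(2, dim=1).mean().item()


# Toy SN critic, only to show the calls
torch.manual_seed(0)
toy_critic = nn.Sequential(
    spectral_norm(nn.Linear(16, 64)), nn.LeakyReLU(0.2), spectral_norm(nn.Linear(64, 1))
)
fake = torch.randn(8, 16)
toy_critic(fake).mean().backward()
print("param grad norm:", grad_norm(toy_critic))
# For a spectrally normalized critic this stays at or below roughly 1
print("input grad norm:", input_grad_norm(toy_critic, fake))
```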