How to correctly apply spectral normalization for WGAN?

I am using aladdinpersson's code for WGAN-GP: Machine-Learning-Collection/ML/Pytorch/GANs/4. WGAN-GP/train.py at master · aladdinpersson/Machine-Learning-Collection · GitHub

and removed the GP part, applying spectral normalization to both the critic and the generator instead (spectral norm rescales each layer's weight by its largest singular value, which is another way to enforce the critic's Lipschitz constraint). The changes:

critic_real = critic(real).reshape(-1)
critic_fake = critic(fake).reshape(-1)
gp = gradient_penalty(critic, real, fake, device=device)
loss_critic = (
    -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
)
critic.zero_grad()

becomes:

critic_real = critic(real).reshape(-1)
critic_fake = critic(fake).reshape(-1)
loss_critic = (
    -(torch.mean(critic_real) - torch.mean(critic_fake))
)
critic.zero_grad()
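
For context: the gradient penalty was what enforced the critic's Lipschitz constraint, and the original WGAN enforced it by clipping the critic's weights after each optimizer step; spectral norm is meant to take over that role here. A sketch of the clipping variant (WEIGHT_CLIP is a hyperparameter, around 0.01 in the WGAN paper):

# After opt_critic.step(), clamp every critic weight into [-WEIGHT_CLIP, WEIGHT_CLIP]
for p in critic.parameters():
    p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)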

And the initialization code:

# initialize gen and disc, note: discriminator should be called critic,
# according to WGAN paper (since it no longer outputs between [0, 1])
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)

becomes:

# initialize gen and disc, note: discriminator should be called critic,
# according to WGAN paper (since it no longer outputs between [0, 1])
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)
apply_spectral_norm(gen)
apply_spectral_norm(critic)

Where apply_spectral_norm() is defined as:

def flatten_network(net):
    # Recursively collect net and all of its sub-modules
    # (a pre-order traversal, similar to list(net.modules()))
    flattened = [flatten_network(child) for child in net.children()]
    res = [net]
    for c in flattened:
        res += c
    return res

def apply_spectral_norm(net):
    # Apply spectral normalization to all conv and linear layers
    for p in flatten_network(net):
        if isinstance(p, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            nn.utils.spectral_norm(p)
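
To double-check that this actually did something: the hook-based nn.utils.spectral_norm mutates the module in place, replacing the weight parameter with weight_orig and registering a weight_u buffer for the power iteration. A rough sanity check, using the critic from above:

for m in flatten_network(critic):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        # True / True means spectral norm was registered on this layer
        print(type(m).__name__, hasattr(m, "weight_orig"), hasattr(m, "weight_u"))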

I expected the GAN to converge at a similar speed, but this is what I see:

(blue is WGAN-GP and red is WGAN-SN)

Any tips on what I am missing?

Bumping this because I still haven’t found a good configuration.

Maaaybe I have found 2 improvements that work.

#1: I think my way of applying spectral norm was possibly incorrect…

Now, I try to apply it to all conv, linear, and normalization layers:

def apply_spectral_norm(net: torch.nn.Module, power=1):
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear,
                          nn.BatchNorm2d, nn.BatchNorm1d,
                          nn.InstanceNorm2d, nn.InstanceNorm1d)):
            nn.utils.parametrizations.spectral_norm(m, n_power_iterations=power)

This also uses the updated, non-deprecated API (torch.nn.utils.parametrizations.spectral_norm instead of the old nn.utils.spectral_norm); n_power_iterations sets how many power-iteration steps are used to estimate the largest singular value, with 1 being the default.

Maybe there is a better way to do this? Like checking whether each module has a “weight” parameter.
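
Something like this might work (an untested sketch; apply_spectral_norm_generic is my own name for it; it would also catch layers such as nn.Embedding, and it skips modules whose weight is None, e.g. InstanceNorm2d with affine=False):

def apply_spectral_norm_generic(net: torch.nn.Module, power: int = 1):
    for m in net.modules():
        # Parametrize any module that owns a learnable "weight" Parameter.
        # getattr(..., None) handles modules without a weight attribute, and
        # the isinstance check skips weights that are None.
        if isinstance(getattr(m, "weight", None), nn.Parameter):
            nn.utils.parametrizations.spectral_norm(m, n_power_iterations=power)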

#2: Now I only apply it to the critic (the discriminator):

# initialize gen and disc, note: discriminator should be called critic,
# according to WGAN paper (since it no longer outputs between [0, 1])
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
critic = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)
initialize_weights(gen)
initialize_weights(critic)
# apply_spectral_norm(gen)
apply_spectral_norm(critic)
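
To verify which modules actually got wrapped, the parametrizations API can be queried; a quick sketch:

from torch.nn.utils import parametrize

# Count the critic modules whose "weight" now carries a parametrization
# (spectral norm is the only one registered here)
n = sum(parametrize.is_parametrized(m, "weight") for m in critic.modules())
print(f"{n} modules in the critic are spectrally normalized")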

This seems to work better than before.

But trying this out on the minimal example described above, WGAN-SN still does not come close to WGAN-GP: