W-GAN loss gets stuck

Hello,

i tried to implement a W-GAN but ran into a stubborn problem.

The loss for both the generator and the critic starts at 0, and slowly the generators loss rise to 1.5 and the critics loss falls to -2.8, and after that the losses stay very close to those values. I tried everything but couldnt get it fixed.

Here is the full code:

import torch
import torch.nn as nn
import torch.utils.data as d_utils
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

transform = transforms.Compose([transforms.Resize((64,64)),transforms.ToTensor(),transforms.Normalize((0.5),(0.5))])
MNIST = torchvision.datasets.MNIST(".data/",transform=transform)

def weight_init(model):
    for module in model.modules():
        if isinstance(module,(nn.Conv2d,nn.ConvTranspose2d,nn.BatchNorm2d)):
            nn.init.normal_(module.weight.data,0.0,0.02)

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = self.create_block(128,1024,4,1,0)
        self.block2 = self.create_block(1024,512,4,2,1)
        self.block3 = self.create_block(512,256,4,2,1)
        self.block4 = self.create_block(256,128,4,2,1)
        self.block5 = nn.Sequential(nn.ConvTranspose2d(128,1,4,2,1),nn.BatchNorm2d(1),nn.Tanh())

    def create_block(self,in_f,out_f,kernel,stride,pad):
        deconv = nn.ConvTranspose2d(in_f,out_f,kernel,stride,pad)
        batch_norm = nn.BatchNorm2d(out_f)
        relu = nn.ReLU()
        return nn.Sequential(deconv,batch_norm,relu)

    def forward(self,x):
        x = x.view(-1,128,1,1)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        return x
    
class Critic(nn.Module):
    def __init__(self,leak_val):
        super().__init__()
        self.block1 = nn.Sequential(nn.Conv2d(1,128,4,2,1),nn.LeakyReLU(leak_val))
        self.block2 = self.create_block(128,256,4,2,1,leak_val)
        self.block3 = self.create_block(256,512,4,2,1,leak_val)
        self.block4 = self.create_block(512,1024,4,2,1,leak_val)
        self.block5 = nn.Sequential(nn.Conv2d(1024,1,4,1,0))
    
    def create_block(self,in_f,out_f,kernel,stride,pad,leak_val):
        conv = nn.Conv2d(in_f,out_f,kernel,stride,pad)
        batch_norm = nn.InstanceNorm2d(out_f,affine=True)
        l_relu = nn.LeakyReLU(leak_val)
        return nn.Sequential(conv,batch_norm,l_relu)
    
    def forward(self,x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        return x
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 64
MNIST = d_utils.DataLoader(MNIST,batch_size,True)

epochs = 500
lr_crit = 0.00005
lr_gen = 0.00005
weight_clip = 0.01
leak_val = 0.2

crit = Critic(leak_val).to(device)
gen = Generator().to(device)
weight_init(crit)
weight_init(gen)

crit_optim = torch.optim.RMSprop(crit.parameters(),lr_crit)
gen_optim = torch.optim.RMSprop(gen.parameters(),lr_gen)

noise_dim = 128
fixed_noise = torch.rand((batch_size,noise_dim)).to(device)

writer_fake = SummaryWriter(r"C:\Users\mehme\Logs\fake")
writer_real = SummaryWriter(r"C:\Users\mehme\Logs\real")

n_crit_step = 5

a = 0
c_rl,g_rl = 0,0
for epoch in range(1,epochs+1):
    for i,(real,_) in enumerate(MNIST):
        batch_size = real.shape[0]
        real = real.to(device).view(batch_size,1,64,64)
        
        for j in range(n_crit_step):
            noise = torch.rand((batch_size,noise_dim)).to(device)
        
            fake = gen(noise).to(device)
        
            real_pred = crit(real)
            fake_pred = crit(fake.detach())
            c_loss = torch.mean(fake_pred,0) - torch.mean(real_pred,0)

            crit_optim.zero_grad()
            c_loss.backward()
            crit_optim.step()

            for p in crit.parameters():
                p.data.clamp_(-weight_clip,weight_clip)

            c_rl += c_loss.detach()

        fake_pred = crit(fake)
        gen_loss = -torch.mean(fake_pred)

        g_rl += gen_loss.detach()

        gen_optim.zero_grad()
        gen_loss.backward()
        gen_optim.step()

        print(gen_loss,c_loss,i,len(MNIST))

        if not i%100:
            with torch.no_grad():
                gen.eval()
                fake = gen(fixed_noise).to(device)
                div_term = ((epoch-1)*len(MNIST))+i
                print(f"c_loss:{c_rl/(div_term*n_crit_step)},gen_loss:{g_rl/div_term}")
                gen.train()

                torch.save(gen.state_dict(),"generator.pth")
                torch.save(crit.state_dict(),"discriminator.pth")

                fake = gen(fixed_noise)
                data = real

                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image("Fake Images", img_grid_fake, global_step=a)
                writer_real.add_image("Real Images", img_grid_real, global_step=a)
                a += 1

I’ve checked if the generated images were actually getting better but they weren’t.