Hello,
I tried to implement a WGAN but ran into a stubborn problem.
The loss for both the generator and the critic starts at 0; over time the generator's loss slowly rises to about 1.5 while the critic's loss falls to about -2.8, and after that both losses stay very close to those values. I have tried everything I can think of but couldn't get it fixed.
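For reference, this is the loss formulation the training loop below implements (standard WGAN with weight clipping); `real_pred` and `fake_pred` here are just placeholder tensors standing in for the critic's outputs on a real and a generated batch:

```python
import torch

# placeholder critic scores for a batch of 64 real and 64 generated images
real_pred = torch.randn(64, 1)
fake_pred = torch.randn(64, 1)

# critic loss: drive the score of fakes down and the score of reals up
c_loss = fake_pred.mean() - real_pred.mean()
# generator loss: drive the critic's score of fakes up
gen_loss = -fake_pred.mean()
print(c_loss.item(), gen_loss.item())
```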
Here is the full code:
```python
import torch
import torch.nn as nn
import torch.utils.data as d_utils
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

# resize MNIST to 64x64 and scale pixel values to [-1, 1]
transform = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
MNIST = torchvision.datasets.MNIST(".data/", transform=transform)

def weight_init(model):
    # initialise conv, deconv and batch-norm weights from N(0, 0.02)
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(module.weight.data, 0.0, 0.02)

class Generator(nn.Module):
    # DCGAN-style generator: 128-dim latent vector -> 1x64x64 image
    def __init__(self):
        super().__init__()
        self.block1 = self.create_block(128, 1024, 4, 1, 0)  # 1x1   -> 4x4
        self.block2 = self.create_block(1024, 512, 4, 2, 1)  # 4x4   -> 8x8
        self.block3 = self.create_block(512, 256, 4, 2, 1)   # 8x8   -> 16x16
        self.block4 = self.create_block(256, 128, 4, 2, 1)   # 16x16 -> 32x32
        # final block: 32x32 -> 64x64, batch norm then Tanh squashes the output to [-1, 1]
        self.block5 = nn.Sequential(nn.ConvTranspose2d(128, 1, 4, 2, 1), nn.BatchNorm2d(1), nn.Tanh())

    def create_block(self, in_f, out_f, kernel, stride, pad):
        # deconv -> batch norm -> ReLU
        deconv = nn.ConvTranspose2d(in_f, out_f, kernel, stride, pad)
        batch_norm = nn.BatchNorm2d(out_f)
        relu = nn.ReLU()
        return nn.Sequential(deconv, batch_norm, relu)

    def forward(self, x):
        x = x.view(-1, 128, 1, 1)  # reshape flat latent vectors to 128x1x1 feature maps
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        return x

class Critic(nn.Module):
    # DCGAN-style critic: 1x64x64 image -> single unbounded score
    def __init__(self, leak_val):
        super().__init__()
        self.block1 = nn.Sequential(nn.Conv2d(1, 128, 4, 2, 1), nn.LeakyReLU(leak_val))  # 64x64 -> 32x32
        self.block2 = self.create_block(128, 256, 4, 2, 1, leak_val)   # 32x32 -> 16x16
        self.block3 = self.create_block(256, 512, 4, 2, 1, leak_val)   # 16x16 -> 8x8
        self.block4 = self.create_block(512, 1024, 4, 2, 1, leak_val)  # 8x8   -> 4x4
        self.block5 = nn.Sequential(nn.Conv2d(1024, 1, 4, 1, 0))       # 4x4   -> 1x1 score

    def create_block(self, in_f, out_f, kernel, stride, pad, leak_val):
        # conv -> instance norm (with affine parameters) -> LeakyReLU
        conv = nn.Conv2d(in_f, out_f, kernel, stride, pad)
        batch_norm = nn.InstanceNorm2d(out_f, affine=True)
        l_relu = nn.LeakyReLU(leak_val)
        return nn.Sequential(conv, batch_norm, l_relu)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# hyperparameters: RMSprop with lr 5e-5, weight clipping at 0.01, 5 critic steps per generator step
batch_size = 64
MNIST = d_utils.DataLoader(MNIST, batch_size, True)  # shuffle=True
epochs = 500
lr_crit = 0.00005
lr_gen = 0.00005
weight_clip = 0.01
leak_val = 0.2

crit = Critic(leak_val).to(device)
gen = Generator().to(device)
weight_init(crit)
weight_init(gen)
crit_optim = torch.optim.RMSprop(crit.parameters(), lr_crit)
gen_optim = torch.optim.RMSprop(gen.parameters(), lr_gen)

noise_dim = 128
fixed_noise = torch.rand((batch_size, noise_dim)).to(device)  # fixed latent batch for monitoring samples

writer_fake = SummaryWriter(r"C:\Users\mehme\Logs\fake")
writer_real = SummaryWriter(r"C:\Users\mehme\Logs\real")

n_crit_step = 5    # critic updates per generator update
a = 0              # global step counter for the TensorBoard image grids
c_rl, g_rl = 0, 0  # running sums of the critic and generator losses

for epoch in range(1, epochs + 1):
    for i, (real, _) in enumerate(MNIST):
        batch_size = real.shape[0]
        real = real.to(device).view(batch_size, 1, 64, 64)

        # critic: n_crit_step updates per generator update
        for j in range(n_crit_step):
            noise = torch.rand((batch_size, noise_dim)).to(device)
            fake = gen(noise).to(device)
            real_pred = crit(real)
            fake_pred = crit(fake.detach())
            # WGAN critic loss: mean score of fakes minus mean score of reals
            c_loss = torch.mean(fake_pred, 0) - torch.mean(real_pred, 0)
            crit_optim.zero_grad()
            c_loss.backward()
            crit_optim.step()
            # enforce the Lipschitz constraint by clipping the critic's weights
            for p in crit.parameters():
                p.data.clamp_(-weight_clip, weight_clip)
            c_rl += c_loss.detach()

        # generator: push the critic's score of the fakes up
        fake_pred = crit(fake)
        gen_loss = -torch.mean(fake_pred)
        g_rl += gen_loss.detach()
        gen_optim.zero_grad()
        gen_loss.backward()
        gen_optim.step()

        print(gen_loss, c_loss, i, len(MNIST))

        # every 100 batches: print running average losses, save checkpoints, log image grids
        if not i % 100:
            with torch.no_grad():
                gen.eval()
                fake = gen(fixed_noise).to(device)
                div_term = ((epoch - 1) * len(MNIST)) + i
                print(f"c_loss:{c_rl/(div_term*n_crit_step)},gen_loss:{g_rl/div_term}")
                gen.train()
                torch.save(gen.state_dict(), "generator.pth")
                torch.save(crit.state_dict(), "discriminator.pth")
                fake = gen(fixed_noise)
                data = real
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)
                writer_fake.add_image("Fake Images", img_grid_fake, global_step=a)
                writer_real.add_image("Real Images", img_grid_real, global_step=a)
                a += 1
```
I've checked whether the generated images were actually getting better, but they weren't.
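In case anyone wants to look at the samples themselves, here is a minimal sketch for rendering a grid from the saved generator.pth (it reuses the Generator class from the listing above; writing the grid to a PNG with torchvision's save_image is just one way to view it):

```python
import torch
import torchvision

# load the saved generator checkpoint on CPU
gen = Generator()
gen.load_state_dict(torch.load("generator.pth", map_location="cpu"))
gen.eval()

with torch.no_grad():
    noise = torch.rand((64, 128))  # same uniform noise distribution as in training
    samples = gen(noise)

# arrange the 64 samples in a grid and write it to disk for inspection
grid = torchvision.utils.make_grid(samples, normalize=True)
torchvision.utils.save_image(grid, "samples.png")
```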