Wasserstein loss layer/criterion

(Thomas V) #43


yes, conv double backward was merged into master yesterday:
https://github.com/pytorch/pytorch/pull/1832

I think most of the other functions are there as well.
Thanks to those involved for all the hard work!

Best regards


(Jarrel Seah) #44

Is BatchNorm supported? I am getting an error that BatchNormBackward is not supported.

(Jarrel Seah) #45

Hi Tom,

I have been trying out your improvements to WGAN (thanks!) - for lipschitz_constraint == 1 in https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/Improved_Training_of_Wasserstein_GAN.ipynb

Do you think taking the mean instead of the sum in this line: `dist = ((vinput - vinput2)**2).sum(1)**0.5` would work better for high-dimensional inputs, as the Euclidean norm (especially for images) quickly exceeds the order of magnitude of the discriminator outputs?

My initial experiments suggest that the mean does work for images, but the sum does not.
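For concreteness, here is a small numpy sketch (the dimension is made up for illustration, not taken from the notebook) of why the sum-based norm outgrows a mean-based one as the dimension rises:

```python
import numpy as np

rng = np.random.default_rng(0)
D = 3072  # e.g. a flattened 3x32x32 image; illustrative only

x = rng.normal(size=D)
y = rng.normal(size=D)

# sum-based Euclidean norm, as in the notebook's dist line:
dist_sum = (((x - y) ** 2).sum()) ** 0.5
# mean-based variant: taking the mean inside the square root divides
# the distance by sqrt(D), normalizing away the dimension dependence
dist_mean = (((x - y) ** 2).mean()) ** 0.5

# for random inputs, dist_sum grows roughly like sqrt(D) while
# dist_mean stays O(1)
assert np.isclose(dist_mean, dist_sum / np.sqrt(D))
```

So the mean variant is just the sum variant rescaled by 1/sqrt(D), which keeps the distance on the same order of magnitude as typical critic outputs.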

(Csaba Botos) #46


Your code is very clear and helpful (at least for me).
Please help me understand this algorithm better: you backprop 1 for fake and -1 for real in this code:

While this notebook by @tom (which is enlightening as well) https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/Semi-Improved_Training_of_Wasserstein_GAN.ipynb backprops -1 for fake and 1 for real.

Maybe I mistook some calculations, but if I’m correct, then we should decrease the output for real (since the EM distance measures how far we are from the real distribution, so for real samples we should decrease it) and increase the output for fake (where “output” is the output value of the discriminator/critic). This means your code is correct, yet I’m still confused, because @tom’s code seems to work too.

Or is this some kind of symmetric case where we only have to increase the distance between the real and fake discriminations, no matter in which direction?

Thank you :slight_smile:

(Marvin Cao) #47

Maybe you are right. In my opinion, @tom’s code regards the output of the discriminator as the error, while my code takes it as the Wasserstein distance. They are optimized in opposite gradient directions. I think both ways may lead the algorithm to converge, though I haven’t tested it with my current code. (Maybe I will test it in a few days.)

(Csaba Botos) #48

But for the sake of general understanding:
From the point of view of the critic, we want to decrease the Wasserstein distance of the critic’s implicit probability density with respect to the real distribution, while increasing the Wasserstein distance with respect to the **generator’s** implicit probability density. Am I right?

(Marvin Cao) #49

Yes, you are right. Explaining it this way is more understandable. :smiley:

(Thomas V) #50

Hi Csaba, Jarrel,

thank you for looking at this in detail!

I must admit that the mathematician in me cringes a bit at @botcs’s argument.
As @jarrelscy mentions, this is symmetric (it is a distance, after all).

What happens mathematically is that the discriminator (the test function in the supremum) will ideally converge to the negative of what you get when you switch the signs between real and fake. The only important thing is to have opposing signs between the pair (real discr, fake discr) in the discriminator loss and between (fake discr, fake gen) in the generator loss; the latter because we want to maximize the difference between the integrals in the discriminator, but do so by minimizing the negative.

So, approximately (if the penalty term were zero because its weight was infinite), the Wasserstein distance is the negative of the discriminator loss; the generator loss lacks the subtraction of the integral over the real samples to be the true Wasserstein distance, but as that term does not enter the gradient anyway, it is not computed. This is independent of how you pick the signs.
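To make the sign symmetry concrete, a tiny numeric sketch (the critic outputs are made-up numbers, and the loss names are mine, not from the notebooks):

```python
import numpy as np

# Toy critic outputs on a batch of real and fake samples (made-up
# numbers, just to illustrate the sign symmetry).
d_real = np.array([0.8, 1.1, 0.9])
d_fake = np.array([-0.2, 0.1, 0.0])

# Convention A: the critic minimizes mean(D(fake)) - mean(D(real)),
# and the generator minimizes -mean(D(fake)).
critic_loss_a = d_fake.mean() - d_real.mean()
gen_loss_a = -d_fake.mean()

# Convention B (signs flipped): the critic minimizes
# mean(D(real)) - mean(D(fake)), the generator minimizes mean(D(fake)).
critic_loss_b = d_real.mean() - d_fake.mean()
gen_loss_b = d_fake.mean()

# The two conventions are exact negatives of each other: a critic
# trained under B ideally converges to the negative of one under A.
assert np.isclose(critic_loss_a, -critic_loss_b)
assert np.isclose(gen_loss_a, -gen_loss_b)
```

Either convention works, as long as the critic and generator losses use opposing signs on the fake term, as described above.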

I took the signs from the WGAN code published by Martin Arjovsky on GitHub.

Note that some time after writing the notebook you linked, I came to the conclusion that the one-sided loss is better.

Best regards


(Thomas V) #51

Hi Jarrel,

Thank you for the observation.

Between this and the magnitude of the penalty parameter, it should be equivalent to move it. What you are effectively doing is changing the metric on the image space (from the Euclidean distance to the Euclidean distance divided by the number of pixels).

That said, it seems good not to have a penalty term that is overly large compared to the “primary” terms. My impression is that with the one-sided penalty, larger weights are much less problematic, as witnessed by the toy problems.
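To illustrate the difference between the two penalties, a minimal sketch, assuming the penalty acts on a Lipschitz ratio such as |D(x) - D(y)| / dist(x, y); the function names are mine, not from the notebooks:

```python
import numpy as np

def two_sided_penalty(lip_ratio):
    # standard WGAN-GP-style penalty: push the Lipschitz ratio
    # toward exactly 1, punishing both too-steep and too-flat critics
    return (lip_ratio - 1.0) ** 2

def one_sided_penalty(lip_ratio):
    # only penalize ratios above 1; being "too flat" is allowed,
    # which is why large penalty weights are less harmful here
    return np.clip(lip_ratio - 1.0, 0.0, None) ** 2

ratios = np.array([0.5, 1.0, 1.5])
assert np.allclose(two_sided_penalty(ratios), [0.25, 0.0, 0.25])
assert np.allclose(one_sided_penalty(ratios), [0.0, 0.0, 0.25])
```

With the one-sided version, a large weight only forces the critic to stay 1-Lipschitz; it does not push it toward having gradient norm exactly 1 everywhere.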

Best regards


(YangFangshu) #52

Hi Tom,

I am very impressed by your code: https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/Pytorch_Wasserstein.ipynb. You have shown an example with 1D data, but I would like to know how to calculate the Wasserstein loss for a 2D training dataset. How do you compute the distance between two 2D matrices? Or 3D matrices?

It would be interesting if you could show an example with 2D data.

Best regards,

F.S. Yang

(Thomas V) #53


great question!
It should work for more dimensions in the sense that you just need to plug in the right distance function. The number of bins could be a limitation, though: with 100 data points you only get a 10x10 grid.
I should include a sample, really.
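As a rough sketch of what “plugging in the right distance function” might look like for the 2D case: bin the data into an n-by-n histogram and build the ground-distance matrix between cell centers, then hand the flattened histograms and this matrix to the solver. The grid size and construction here are my assumptions, not from the notebook:

```python
import numpy as np

# 10x10 grid, matching the "100 datapoints" example above
n = 10

# coordinates of the n*n cell centers, as rows of a (n*n, 2) array
centers = np.stack(
    np.meshgrid(np.arange(n), np.arange(n), indexing="ij"), axis=-1
).reshape(-1, 2)

# (n*n) x (n*n) ground-distance matrix: Euclidean distance between
# every pair of cell centers; this replaces the 1D |i - j| distances
cost = np.linalg.norm(centers[:, None, :] - centers[None, :, :], axis=-1)

assert cost.shape == (n * n, n * n)
assert np.allclose(cost, cost.T)  # a metric cost is symmetric
assert cost[0, 1] == 1.0          # adjacent cells are distance 1 apart
```

The Wasserstein computation itself then operates on the flattened 100-dimensional histograms with this cost matrix, exactly as in the 1D case.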

Best regards


(YangFangshu) #54

Hi tom,

Thank you for your reply. In computer vision we often process 2D images, and I find that computing the Wasserstein loss between two 2D matrices iteratively is computationally very expensive. How can we deal with this problem? Should we downsample the prediction and the target? And would the loss between the downsampled matrices still be reliable for training a network?

I hope you will show us an example with 2D data.

Best regards,

F.S. Yang
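A quick back-of-the-envelope (my own numbers, not from the thread) on why the 2D case gets expensive and how much downsampling buys:

```python
# Treating an HxW image as a histogram over its pixels, a dense
# ground-cost matrix has (H*W)^2 entries, so memory and compute blow
# up with the fourth power of the side length.
def cost_matrix_entries(h, w):
    return (h * w) ** 2

assert cost_matrix_entries(32, 32) == 1_048_576        # ~1e6 entries
assert cost_matrix_entries(256, 256) == 4_294_967_296  # ~4e9 entries

# downsampling both sides by a factor k shrinks the matrix by k**4:
# going from 256x256 down to 64x64 saves a factor of 4**4 = 256
assert cost_matrix_entries(256, 256) // cost_matrix_entries(64, 64) == 256
```

This is why downsampling (or sparse/entropic approximations of the transport problem) is usually needed before applying a dense Wasserstein solver to full-resolution images.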

(Avhirup Chakraborty) #55

I face the same problem.