I am trying to train a Deep Convolutional GAN (DCGAN) on 3D density data of shape (128, 128, 128) whose values lie in the range (-1, 1700). I standardized and then normalized the data so that it now lies in the range (-1, 1). I am using a learning rate of 1e-5 for both the discriminator and the generator, the latent vector is drawn from a standard normal distribution, and the loss function is BCE loss, but the loss seems to saturate and stay roughly constant after some initial iterations.
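For context, the preprocessing is roughly along these lines (a simplified sketch; the per-volume statistics and the min-max scaling are illustrative choices, only the input range (-1, 1700) and the target range (-1, 1) are given above):

import numpy as np

def preprocess(volume):
    """Standardize a (128, 128, 128) density volume, then min-max scale it to [-1, 1].
    (Sketch only; per-volume statistics are used here for illustration.)"""
    standardized = (volume - volume.mean()) / (volume.std() + 1e-8)
    vmin, vmax = standardized.min(), standardized.max()
    return (2.0 * (standardized - vmin) / (vmax - vmin) - 1.0).astype(np.float32)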
The Discriminator and Generator networks have the following architectures (a quick shape check follows the two class definitions):
import torch
import torch.nn as nn

class GeneratorNet(nn.Module):
    def brick(self, inchn, outchn, k, s, p, bias=False):
        return nn.ConvTranspose3d(inchn, outchn, k, s, p, bias=bias)

    def postop(self, chn):
        return nn.Sequential(
            nn.BatchNorm3d(chn),
            nn.ReLU(True),)

    def __init__(self, ngpu):
        super(GeneratorNet, self).__init__()
        self.ngpu = ngpu
        #--input or z is BS,200
        self.linear1 = torch.nn.Linear(200, 256)
        #---reshape into BS,256,1,1,1
        self.brick1 = self.brick(256, 128, k=4, s=2, p=1)  #---BS,128,2,2,2
        self.postop1 = self.postop(128)
        self.brick2 = self.brick(128, 64, k=4, s=2, p=1)   #---BS,64,4,4,4
        self.postop2 = self.postop(64)
        self.brick3 = self.brick(64, 32, k=4, s=2, p=1)    #---BS,32,8,8,8
        self.postop3 = self.postop(32)
        self.brick4 = self.brick(32, 16, k=4, s=2, p=1)    #---BS,16,16,16,16
        self.postop4 = self.postop(16)
        self.brick5 = self.brick(16, 8, k=4, s=2, p=1)     #---BS,8,32,32,32
        self.postop5 = self.postop(8)
        self.brick6 = self.brick(8, 4, k=4, s=2, p=1)      #---BS,4,64,64,64
        self.postop6 = self.postop(4)
        self.brick7 = self.brick(4, 1, k=4, s=2, p=1)      #---BS,1,128,128,128
        self.postop7 = self.postop(1)                      #---not used in forward; Tanh is applied instead
        self.final_activation = torch.nn.Tanh()

    def forward(self, z):
        #--Generator takes z in the shape of BS,200
        y = self.linear1(z)
        bs, lv = y.shape
        y = y.reshape(bs, lv, 1, 1, 1)
        y = self.brick1(y)
        y = self.postop1(y)
        y = self.brick2(y)
        y = self.postop2(y)
        y = self.brick3(y)
        y = self.postop3(y)
        y = self.brick4(y)
        y = self.postop4(y)
        y = self.brick5(y)
        y = self.postop5(y)
        y = self.brick6(y)
        y = self.postop6(y)
        y = self.brick7(y)
        y = self.final_activation(y)
        return y
class DiscriminatorNet(nn.Module):
    def brick(self, inchn, outchn, k, s, p, bias=False):
        return nn.Conv3d(inchn, outchn, k, s, p, bias=bias)

    def firstpostop(self):  #---LeakyReLU only, no BatchNorm; not used in forward below
        return nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),)

    def postop(self, chn):
        return nn.Sequential(
            nn.BatchNorm3d(chn, affine=True),
            nn.LeakyReLU(0.2, inplace=True),)

    def __init__(self, ngpu):
        super(DiscriminatorNet, self).__init__()
        self.ngpu = ngpu
        #--input is BS,1,128,128,128
        self.brick1 = self.brick(1, 4, k=4, s=2, p=1)      #--BS,4,64,64,64
        self.postop1 = self.postop(4)
        self.brick2 = self.brick(4, 8, k=4, s=2, p=1)      #--BS,8,32,32,32
        self.postop2 = self.postop(8)
        self.brick3 = self.brick(8, 16, k=4, s=2, p=1)     #--BS,16,16,16,16
        self.postop3 = self.postop(16)
        self.brick4 = self.brick(16, 32, k=4, s=2, p=1)    #--BS,32,8,8,8
        self.postop4 = self.postop(32)
        self.brick5 = self.brick(32, 64, k=4, s=2, p=1)    #--BS,64,4,4,4
        self.postop5 = self.postop(64)
        self.brick6 = self.brick(64, 128, k=4, s=2, p=1)   #--BS,128,2,2,2
        self.postop6 = self.postop(128)
        self.brick7 = self.brick(128, 256, k=2, s=1, p=0)  #---BS,256,1,1,1
        self.linear1 = torch.nn.Linear(256*1*1*1, 1)
        self.final_activation = torch.nn.Sigmoid()

    def forward(self, g_out):
        y = self.brick1(g_out)
        y = self.postop1(y)
        y = self.brick2(y)
        y = self.postop2(y)
        y = self.brick3(y)
        y = self.postop3(y)
        y = self.brick4(y)
        y = self.postop4(y)
        y = self.brick5(y)
        y = self.postop5(y)
        y = self.brick6(y)
        y = self.postop6(y)
        y = self.brick7(y)
        bs, ch, h, w, d = y.shape
        y = y.reshape(bs, ch*h*w*d)
        y = self.linear1(y)
        y = self.final_activation(y)
        return y
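As a quick sanity check (not part of the training script), pushing random tensors through both networks reproduces the shapes noted in the comments:

import torch

G = GeneratorNet(ngpu=1)
D = DiscriminatorNet(ngpu=1)

z = torch.normal(mean=0.0, std=1.0, size=(4, 200))  # batch of 4 latent vectors
fake = G(z)
print(fake.shape)     # torch.Size([4, 1, 128, 128, 128])
print(D(fake).shape)  # torch.Size([4, 1]) -- sigmoid outputs in (0, 1)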
The discriminator is updated only if D(G(z)) > PROB, where three different values of PROB were tested. The training loop is as follows:
import numpy as np
import torch.optim as optim

PROB = 0.2  #---Select from 0.0 (standard), 0.2 and 0.25
ngpu = 2
batch_size = 4
num_workers = 4
nz = 200
G_lr, G_betas = 1e-5, (0.5, 0.999)  #5e-5, 0.5 #--0.0001, beta=0.6
D_lr, D_betas = 1e-5, (0.5, 0.999)  #5e-5, 0.5
num_epochs = 50
print_freq = 200
save_freq = 2000
iters = 1
real_label = 1.
fake_label = 0.
num_batches = len(mydataloader)  #---mydataloader and main_collector (the device) are defined elsewhere
img_list = []                    #---collects generated samples for saving
fixed_noise = torch.normal(mean=0.0, std=1.0, size=(2, nz), device=main_collector)
D_optimizer = optim.Adam(Discriminator.parameters(), lr=D_lr, betas=D_betas)
G_optimizer = optim.Adam(Generator.parameters(), lr=G_lr, betas=G_betas)
criterion = torch.nn.BCELoss()
for epoch in range(num_epochs):
    for i, data in enumerate(mydataloader, 1):
        ##############################
        ######## DISCRIMINATOR #######
        ##############################
        Discriminator.zero_grad()  #---Set gradients to 0. Use model.zero_grad() if >1 optimizers for the same model
        real_data = data[0].to(main_collector)  #---get real data and conditional parameters
        cbs = real_data.shape[0]                #---current batch size (it can be smaller for the last batch)
        label = torch.full((cbs,), real_label, dtype=torch.float, device=main_collector)
        #--TRAIN DISCRIMINATOR WITH ALL-REAL BATCH
        real_validity = Discriminator(real_data).view(-1)
        D_loss_real = criterion(real_validity, label)
        D_loss_real.backward()
        D_x = real_validity.mean().item()
        #--TRAIN DISCRIMINATOR WITH ALL-FAKE BATCH
        z = torch.normal(mean=0.0, std=1.0, size=(cbs, nz), device=main_collector)
        fake_data = Generator(z)
        label.fill_(fake_label)
        fake_validity = Discriminator(fake_data.detach()).view(-1)
        D_loss_fake = criterion(fake_validity, label)
        D_loss_fake.backward()
        D_G_z1 = fake_validity.mean().item()
        #--TOTAL DISCRIMINATOR LOSS
        D_loss = D_loss_real + D_loss_fake
        if D_G_z1 > PROB:
            D_optimizer.step()
        ##############################
        ######## GENERATOR ###########
        ##############################
        Generator.zero_grad()
        label.fill_(real_label)  #---[1,1,1,...]: fake labels are treated as real for the generator cost
        output = Discriminator(fake_data).view(-1)  #---D was just updated, so do another forward pass on the all-fake batch
        G_loss = criterion(output, label)           #---calculate G's loss based on this output
        G_loss.backward()
        D_G_z2 = output.mean().item()
        G_optimizer.step()
        ##############################
        ######## DETAILS #############
        ##############################
        if i % print_freq == 0:
            print('Epoch: %d/%d \tIteration: %d/%d \tD_loss: %.3f, G_loss: %.3f, \tD(x): %.4f, D(G(z)): %.4f/%.4f'
                  % (epoch+1, num_epochs, iters, int(num_epochs*num_batches), D_loss.item(), G_loss.item(), D_x, D_G_z1, D_G_z2))
        if iters % save_freq == 0:
            with torch.no_grad():
                fake_data = Generator(fixed_noise).detach().cpu().numpy()
            img_list.append(fake_data)
            np.save("img_list.npy", np.array(img_list))
        iters += 1
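For the comparison below, I look at a central slice of a generated volume next to a real one, roughly like this (a sketch; the slice index and colormap are illustrative choices):

import numpy as np
import matplotlib.pyplot as plt

# img_list.npy has shape (num_saves, 2, 1, 128, 128, 128): one entry per save, 2 fixed-noise samples each
img_list = np.load("img_list.npy")
fake_volume = img_list[-1, 0, 0]             # latest sample from the first fixed-noise vector
real_volume = real_data[0, 0].cpu().numpy()  # one real training volume (taken from the last batch here)

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(real_volume[64], cmap="viridis")  # central slice of the real volume
axes[0].set_title("real")
axes[1].imshow(fake_volume[64], cmap="viridis")  # central slice of the generated volume
axes[1].set_title("generated")
plt.show()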
The generated images (right) are not even close to the ones trained on (left).
It seems that something is fundamentally wrong with my approach. What am I doing wrong here?