Hi, I am trying to train a progressive GAN. The losses for the 4x4 layer start off in a range which makes sense however after growing the network they explode and the gen images have rainbow patterns. I include my model code here and the training loop
# Let's define a function which can generate the conv block
def d_conv_block(in_channels, out_channels, kernel_size1=None, kernel_size2=None):
if kernel_size2 is not None:
block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size1, padding=(1,1)),
nn.LeakyReLU(0.2),
nn.Conv2d(out_channels, out_channels, kernel_size2),
nn.LeakyReLU(0.2),
)
else:
block = nn.Sequential(
nn.Conv2d(in_channels, in_channels, kernel_size1, padding=(1,1)),
nn.LeakyReLU(0.2),
nn.Conv2d(in_channels, out_channels, kernel_size1, padding=(1,1)),
nn.LeakyReLU(0.2),
# Downsample
nn.AvgPool2d(kernel_size=(2,2)),
)
return block
def g_conv_block(in_channels, out_channels, kernel_size1=None, kernel_size2=None, upsample=False):
if upsample:
block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size1, padding=(1,1)),
LRN(),
nn.LeakyReLU(0.2),
nn.Conv2d(out_channels, out_channels, kernel_size1, padding=(1,1)),
LRN(),
nn.LocalResponseNorm(x.size(0), alpha=1, beta=2, k=10e-8),
nn.LeakyReLU(0.2),
)
else:
block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size1, padding=(3,3)),
nn.LeakyReLU(0.2),
nn.Conv2d(out_channels, out_channels, kernel_size2, padding=(1,1)),
LRN(),
nn.LeakyReLU(0.2),
)
return block
def d_output_layer(input_dim):
layer = nn.Linear(input_dim, 1)
return layer
def from_to_RGB(in_channels=None, out_channels=None):
block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=(1,1)),
nn.LeakyReLU(0.2),
)
return block
def upsample(channels):
return nn.ConvTranspose2d(in_channels=channels, out_channels=channels, kernel_size=2, stride=2)
class Mbatch_stddev(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
N, _, H, W = x.shape
# First calculate the stddev for each feature in each spatial location over the batch
# Which means calculate the stddev of each feature map
featuremap_stddevs = torch.std(x, dim=0, unbiased=False)
# Then average these estimates over all features and spatial locations to arrive at a single value
mean_stddev = torch.mean(featuremap_stddevs)
stddev_featuremap = mean_stddev * torch.ones((N, 1, H, W), device=x.device)
x = torch.cat((x, stddev_featuremap), dim=1)
return x
class LRN(nn.Module):
def __init__(self, epsilon=1e-8):
super(LRN, self).__init__()
self.epsilon = epsilon
def forward(self, x):
square_sum = torch.pow(x, 2).sum(dim=1, keepdim=True) # Sum across all channels
norm_factor = torch.sqrt(square_sum / x.size(1) + self.epsilon) # Divide by N (number of channels)
return x / norm_factor # Normalize
class Generator_32(nn.Module):
def __init__(self):
super().__init__()
# Lets try using convtranspose2d for upsample
# self.upsample = nn.Upsample(scale_factor=2, mode='nearest'
self.blocks = nn.ModuleList()
self.to_rgb = nn.Identity()
self.res_to_rgb = nn.Identity()
self.upsample_res = nn.Identity()
# For managing res connection
# self.block_outputs = [] This approach doesn't work for multiple inputs (e.g batches in a dataset)
# self.prev_to_rgb = nn.Identity()
# self.res_flag = False
def forward(self, x, alpha):
block_outputs = []
for i, block in enumerate(self.blocks):
'''if i == len(self.blocks) - 1:
res_x = torch.clone(x)'''
x = block(x)
block_outputs.append(x)
x = self.to_rgb(x)
if len(block_outputs) >= 2:
#if self.res_flag:
res_x = block_outputs[-2] # -2 accesses second to the front of list
res_x = self.upsample_res(res_x)
res_x = self.res_to_rgb(res_x)
#res_x = self.prev_to_rgb(res_x)
#res_x = self.upsample_res(res_x)
out = ((1-alpha) * res_x) + (alpha * x)
else:
out = x
return out
def add_gen_block(self, layer_num):
all_out_channels = [512, 256, 128, 64] # The number of output channels for layer 1 - 4
# new_params keeps track of new params for optimiser
new_params = []
if layer_num == 0:
self.blocks.append(g_conv_block(in_channels=512, out_channels=512, kernel_size1=(4,4), kernel_size2=(3,3))).to(device)
new_params.extend(self.blocks[-1].parameters())
self.to_rgb = from_to_RGB(in_channels=512, out_channels=3).to(device)
new_params.extend(self.to_rgb.parameters())
else:
in_channels = all_out_channels[layer_num] * 2
out_channels = all_out_channels[layer_num]
# Upsample before appending new layer
self.blocks.append(upsample(in_channels).to(device))
new_params.extend(self.blocks[-1].parameters())
self.blocks.append(g_conv_block(in_channels=in_channels, out_channels=out_channels, kernel_size1=(3,3), kernel_size2=(3,3), upsample=True).to(device))
new_params.extend(self.blocks[-1].parameters())
# Store old to_rgb
# self.prev_to_rgb = self.to_rgb
self.to_rgb = from_to_RGB(in_channels=out_channels, out_channels=3).to(device)
new_params.extend(self.to_rgb.parameters())
res_out_channels = all_out_channels[layer_num-1]
self.upsample_res = upsample(channels=res_out_channels).to(device)
new_params.extend(self.upsample_res.parameters())
self.res_to_rgb = from_to_RGB(in_channels=res_out_channels, out_channels=3).to(device)
new_params.extend(self.res_to_rgb.parameters())
self.res_flag = True
return new_params
g_32 = Generator_32()
g_32 = g_32.to(device)
class Discriminator_32(nn.Module):
def __init__(self):
super().__init__()
self.from_rgb = nn.Identity()
self.res_fromRGB = nn.Identity()
self.blocks = nn.ModuleList()
self.down = nn.AvgPool2d(kernel_size=(2,2)).to(device) # This isnt used for the layers but the res connection
# For managing res connection
self.res_flag = False
self.FC1 = nn.Identity()
def forward(self, x, alpha=0):
res_x = torch.clone(x) # The input to the model is passed to the output of the first layer in the network
#print(f'Res_x.shape before: {res_x.shape}')
x = self.from_rgb(x)
for i, block in enumerate(self.blocks):
x = block(x)
#print('BLOCK' ,block, '\n', x.shape)
if i == 0 and self.res_flag:
res_x = self.down(res_x)
res_x = self.res_fromRGB(res_x)
#print('RES\n', res_x.shape, x.shape)
# Wrong placement here, it needs to be turned off after passing through block 0
#self.res_flag = False
x = ((1-alpha) * res_x) + (alpha * x)
# Last FC layer
x = x.view(x.size(0), -1) # Reshape the output, i.e. flatten it
self.FC1 = d_output_layer(x.size(1)).to(x.device)
#print(x.shape)
x = self.FC1(x)
return x
def add_disc_block(self, layer_num):
all_in_channels = [512, 256, 128, 64, 32]
# I need a way to keep track of new parameters so we can add them to the optimiser
new_params = []
if layer_num == 0:
self.from_rgb = from_to_RGB(in_channels=3, out_channels=512).to(device)
new_params.extend(self.from_rgb.parameters())
self.blocks.append(Mbatch_stddev()).to(device)
new_params.extend(self.blocks[-1].parameters())
self.blocks.append(d_conv_block(in_channels=513, out_channels=512, kernel_size1=(3,3), kernel_size2=(4,4))).to(device)
new_params.extend(self.blocks[-1].parameters())
else:
self.res_fromRGB = from_to_RGB(in_channels=3, out_channels=all_in_channels[layer_num-1]).to(device)
new_params.extend(self.res_fromRGB.parameters())
self.from_rgb = from_to_RGB(in_channels=3, out_channels=all_in_channels[layer_num]).to(device)
new_params.extend(self.from_rgb.parameters())
new_block = [d_conv_block(in_channels=all_in_channels[layer_num], out_channels=all_in_channels[layer_num-1], kernel_size1=(3,3)).to(device)]
new_params.extend(new_block[0].parameters())
self.blocks = nn.ModuleList(new_block + list(self.blocks)).to(device)
self.res_flag = True
return new_params
d_32 = Discriminator_32()
d_32 = d_32.to(device)
for layer in range(4):
print(f'Training layer: {layer+1}')
alpha = 0
# Add the G and D blocks
d_new_params = d_32.add_disc_block(layer_num=layer)
g_new_params = g_32.add_gen_block(layer_num=layer)
print(d_32)
print(g_32)
for epoch_grow in range(25):
for i, data in enumerate(dataloader):
real_images, _ = data
real_images = real_images.to(device)
noise_tensor = torch.randn(batch_size, 512, 1, 1, device=device)
with torch.no_grad():
gen_images = g_32(noise_tensor, alpha=alpha)
real_images = F.interpolate(real_images, size=gen_images.shape[2:])
gen_labels = torch.zeros((batch_size, 1)).to(device)
real_labels = torch.ones((batch_size, 1)).to(device)
combined_images = torch.cat((real_images, gen_images))
combined_labels = torch.cat((real_labels, gen_labels))
#print(gen_images.shape, real_images.shape)
# First update the D model
d_32.zero_grad()
d_outputs_combined = d_32(combined_images, alpha=alpha)
# Init the first D optim
if optim_D is None:
optim_D = torch.optim.Adam(d_32.parameters(), lr=0.00001, betas=(0, 0.99), eps=10**(-8))
else:
# Safely add new discriminator parameters
existing_d_params = set(p for group in optim_D.param_groups for p in group['params'])
new_d_params = set(d_new_params) - existing_d_params
if new_d_params:
optim_D.add_param_group({'params': list(new_d_params)})
# Init the first G optim
if optim_G is None:
optim_G = torch.optim.Adam(g_32.parameters(), lr=0.00001, betas=(0, 0.99), eps=10**(-8))
else:
# Safely add new generator parameters
existing_g_params = set(p for group in optim_G.param_groups for p in group['params'])
new_g_params = set(g_new_params) - existing_g_params
if new_g_params:
optim_G.add_param_group({'params': list(new_g_params)})
loss_d = criterion(d_outputs_combined, combined_labels)
#loss_d, _ = criterion(d_32, real_images, gen_images.detach(), alpha)
loss_d.backward()
optim_D.step()
# Generate new images for updating G
noise_tensor = torch.randn((batch_size, 512, 1, 1)).to(device)
# Next update the G model,
#if i % 2 == 0: # Update generator less frequently
g_32.zero_grad()
gen_images = g_32(noise_tensor, alpha=alpha)
d_outputs_generated = d_32(gen_images, alpha=alpha)
loss_g = criterion(d_outputs_generated, real_labels)
#_, loss_g = criterion(d_32, real_images, gen_images, alpha)
loss_g.backward()
optim_G.step()
imshow(torchvision.utils.make_grid(gen_images.cpu()))
print(f'Layer {layer+1}: Loss_D: {loss_d.item()}, Loss_G: {loss_g.item()}')
alpha += 1/25
alpha = round(alpha, 2)
print(f'Alpha after grow: {alpha}')
for epoch_train in range(25):
for i, data in enumerate(dataloader):
real_images, _ = data
real_images = real_images.to(device)
noise_tensor = torch.randn((batch_size, 512, 1, 1), device=device)
with torch.no_grad():
gen_images = g_32(noise_tensor, alpha=alpha)
real_images = F.interpolate(real_images, size=gen_images.shape[2:])
gen_labels = torch.zeros((batch_size, 1)).to(device)
real_labels = torch.ones((batch_size, 1)).to(device)
combined_images = torch.cat((real_images, gen_images))
combined_labels = torch.cat((real_labels, gen_labels))
# First update the D model
d_32.zero_grad()
d_outputs_combined = d_32(combined_images, alpha=alpha)
loss_d = criterion(d_outputs_combined, combined_labels)
#loss_d, _ = criterion(d_32, real_images, gen_images.detach(), alpha)
loss_d.backward()
optim_D.step()
# Generate new images for updating G
noise_tensor = torch.randn((batch_size, 512, 1, 1)).to(device)
# Next update the G model,
#if i % 2 == 0: # Update generator less frequently
g_32.zero_grad()
gen_images = g_32(noise_tensor, alpha=alpha) # Gen new images for training G
d_outputs_generated = d_32(gen_images, alpha=alpha)
loss_g = criterion(d_outputs_generated, real_labels)
#_, loss_g = criterion(d_32, real_images, gen_images, alpha)
loss_g.backward()
optim_G.step()
print(f'FINAL | Layer {layer+1}: Loss_D: {loss_d.item()}, Loss_G: {loss_g.item()}')
imshow(torchvision.utils.make_grid(real_images.cpu()))
imshow(torchvision.utils.make_grid(gen_images.cpu()))
I am unsure of how to proceed and fix this issue, I’ve been stumped for days now. I would like to ask if anybody can spot an issue in my code or is familiar with this sort of error in training progressive GANs?