Progressive GAN Weights exploding after adding new layer

Hi, I am trying to train a progressive GAN. The losses for the 4x4 layer start off in a range which makes sense however after growing the network they explode and the gen images have rainbow patterns. I include my model code here and the training loop

# Let's define a function which can generate the conv block
def d_conv_block(in_channels, out_channels, kernel_size1=None, kernel_size2=None):
    if kernel_size2 is not None:
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size1, padding=(1,1)),
            nn.LeakyReLU(0.2),
            nn.Conv2d(out_channels, out_channels, kernel_size2),
            nn.LeakyReLU(0.2),
        )
    else:
        block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size1, padding=(1,1)),
            nn.LeakyReLU(0.2),
            nn.Conv2d(in_channels, out_channels, kernel_size1, padding=(1,1)),
            nn.LeakyReLU(0.2),
            # Downsample
            nn.AvgPool2d(kernel_size=(2,2)),
        )
    
    return block

def g_conv_block(in_channels, out_channels, kernel_size1=None, kernel_size2=None, upsample=False):
    if upsample:
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size1, padding=(1,1)),
            LRN(),
            nn.LeakyReLU(0.2),
            nn.Conv2d(out_channels, out_channels, kernel_size1, padding=(1,1)),
            LRN(),
            nn.LocalResponseNorm(x.size(0), alpha=1, beta=2, k=10e-8),
            nn.LeakyReLU(0.2),
        )
    else:
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size1, padding=(3,3)),
            nn.LeakyReLU(0.2),
            nn.Conv2d(out_channels, out_channels, kernel_size2, padding=(1,1)),
            LRN(),
            nn.LeakyReLU(0.2),
        )
    
    return block

def d_output_layer(input_dim):
    layer = nn.Linear(input_dim, 1)
    return layer

def from_to_RGB(in_channels=None, out_channels=None):
    block = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=(1,1)),
        nn.LeakyReLU(0.2),
    )
    return block

def upsample(channels):
    return nn.ConvTranspose2d(in_channels=channels, out_channels=channels, kernel_size=2, stride=2)

class Mbatch_stddev(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        N, _, H, W = x.shape
    
        # First calculate the stddev for each feature in each spatial location over the batch
        # Which means calculate the stddev of each feature map
        featuremap_stddevs = torch.std(x, dim=0, unbiased=False)
        # Then average these estimates over all features and spatial locations to arrive at a single value
        mean_stddev = torch.mean(featuremap_stddevs)
        stddev_featuremap = mean_stddev * torch.ones((N, 1, H, W), device=x.device)

        x = torch.cat((x, stddev_featuremap), dim=1)

        return x

class LRN(nn.Module):
    def __init__(self, epsilon=1e-8):
        super(LRN, self).__init__()
        self.epsilon = epsilon
    
    def forward(self, x):
        square_sum = torch.pow(x, 2).sum(dim=1, keepdim=True)  # Sum across all channels
        norm_factor = torch.sqrt(square_sum / x.size(1) + self.epsilon)  # Divide by N (number of channels)
        return x / norm_factor  # Normalize
        
class Generator_32(nn.Module):
    def __init__(self):
        super().__init__()
        # Lets try using convtranspose2d for upsample
        # self.upsample = nn.Upsample(scale_factor=2, mode='nearest'
        
        self.blocks = nn.ModuleList()
        
        self.to_rgb = nn.Identity()
        self.res_to_rgb = nn.Identity()
        
        self.upsample_res = nn.Identity()
        
        # For managing res connection
        # self.block_outputs = [] This approach doesn't work for multiple inputs (e.g batches in a dataset)
        # self.prev_to_rgb = nn.Identity()
        # self.res_flag = False
        
    def forward(self, x, alpha):
        block_outputs = []
        
        for i, block in enumerate(self.blocks):
            '''if i == len(self.blocks) - 1:
                res_x = torch.clone(x)'''
            x = block(x)
        
        block_outputs.append(x)
        x = self.to_rgb(x)
        
        if len(block_outputs) >= 2:
        #if self.res_flag:
            res_x = block_outputs[-2]  # -2 accesses second to the front of list
            res_x = self.upsample_res(res_x)
            res_x = self.res_to_rgb(res_x)
            #res_x = self.prev_to_rgb(res_x)
            #res_x = self.upsample_res(res_x)
            
            out = ((1-alpha) * res_x) + (alpha * x)
        else:
            out = x
        
        return out
    
    def add_gen_block(self, layer_num):
        all_out_channels = [512, 256, 128, 64]  # The number of output channels for layer 1 - 4
        # new_params keeps track of new params for optimiser
        new_params = []
        
        if layer_num == 0:
            self.blocks.append(g_conv_block(in_channels=512, out_channels=512, kernel_size1=(4,4), kernel_size2=(3,3))).to(device)
            new_params.extend(self.blocks[-1].parameters())
            self.to_rgb = from_to_RGB(in_channels=512, out_channels=3).to(device)
            new_params.extend(self.to_rgb.parameters())
        else:
            in_channels = all_out_channels[layer_num] * 2
            out_channels = all_out_channels[layer_num]
            # Upsample before appending new layer
            self.blocks.append(upsample(in_channels).to(device))
            new_params.extend(self.blocks[-1].parameters())
            self.blocks.append(g_conv_block(in_channels=in_channels, out_channels=out_channels, kernel_size1=(3,3), kernel_size2=(3,3), upsample=True).to(device))
            new_params.extend(self.blocks[-1].parameters())
            
            # Store old to_rgb
            # self.prev_to_rgb = self.to_rgb
            
            self.to_rgb = from_to_RGB(in_channels=out_channels, out_channels=3).to(device)
            new_params.extend(self.to_rgb.parameters())
            
            res_out_channels = all_out_channels[layer_num-1]
            self.upsample_res = upsample(channels=res_out_channels).to(device)
            new_params.extend(self.upsample_res.parameters())
            self.res_to_rgb = from_to_RGB(in_channels=res_out_channels, out_channels=3).to(device)
            new_params.extend(self.res_to_rgb.parameters())
            
            self.res_flag = True
        
        return new_params
    
g_32 = Generator_32()
g_32 = g_32.to(device)

class Discriminator_32(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.from_rgb = nn.Identity()
        self.res_fromRGB = nn.Identity()
        
        self.blocks = nn.ModuleList()
        
        self.down = nn.AvgPool2d(kernel_size=(2,2)).to(device)  # This isnt used for the layers but the res connection
        
        # For managing res connection
        self.res_flag = False
        
        self.FC1 = nn.Identity()
    
    def forward(self, x, alpha=0):
        res_x = torch.clone(x)  # The input to the model is passed to the output of the first layer in the network
        #print(f'Res_x.shape before: {res_x.shape}')
        x = self.from_rgb(x)
        
        for i, block in enumerate(self.blocks):
            x = block(x)
            #print('BLOCK' ,block, '\n', x.shape)
            if i == 0 and self.res_flag:
                res_x = self.down(res_x)
                res_x = self.res_fromRGB(res_x)
                #print('RES\n', res_x.shape, x.shape)
                # Wrong placement here, it needs to be turned off after passing through block 0
                #self.res_flag = False
                x = ((1-alpha) * res_x) + (alpha * x)
        
        # Last FC layer
        x = x.view(x.size(0), -1) # Reshape the output, i.e. flatten it 
        self.FC1 = d_output_layer(x.size(1)).to(x.device)
        #print(x.shape)
        x = self.FC1(x)
        
        return x
    
    def add_disc_block(self, layer_num):
        all_in_channels = [512, 256, 128, 64, 32]
        # I need a way to keep track of new parameters so we can add them to the optimiser
        new_params = []
        
        if layer_num == 0:
            self.from_rgb = from_to_RGB(in_channels=3, out_channels=512).to(device)
            new_params.extend(self.from_rgb.parameters())
            
            self.blocks.append(Mbatch_stddev()).to(device)
            new_params.extend(self.blocks[-1].parameters())
            self.blocks.append(d_conv_block(in_channels=513, out_channels=512, kernel_size1=(3,3), kernel_size2=(4,4))).to(device)
            new_params.extend(self.blocks[-1].parameters())
        else:
            self.res_fromRGB = from_to_RGB(in_channels=3, out_channels=all_in_channels[layer_num-1]).to(device)
            new_params.extend(self.res_fromRGB.parameters())
            self.from_rgb = from_to_RGB(in_channels=3, out_channels=all_in_channels[layer_num]).to(device)
            new_params.extend(self.from_rgb.parameters())
            new_block = [d_conv_block(in_channels=all_in_channels[layer_num], out_channels=all_in_channels[layer_num-1], kernel_size1=(3,3)).to(device)]
            new_params.extend(new_block[0].parameters())
            self.blocks = nn.ModuleList(new_block + list(self.blocks)).to(device)
        
            self.res_flag = True
            
        return new_params
        
d_32 = Discriminator_32() 
d_32 = d_32.to(device)

for layer in range(4):
    print(f'Training layer: {layer+1}')
    alpha = 0
    
    # Add the G and D blocks
    d_new_params = d_32.add_disc_block(layer_num=layer)
    g_new_params = g_32.add_gen_block(layer_num=layer)
    
    print(d_32)
    print(g_32)
    
    
    for epoch_grow in range(25):
        for i, data in enumerate(dataloader):
            real_images, _ = data
            real_images = real_images.to(device)
        
            noise_tensor = torch.randn(batch_size, 512, 1, 1, device=device)

            with torch.no_grad():
                gen_images = g_32(noise_tensor, alpha=alpha)
            
            real_images = F.interpolate(real_images, size=gen_images.shape[2:])
            
            gen_labels = torch.zeros((batch_size, 1)).to(device)
            real_labels = torch.ones((batch_size, 1)).to(device)

            combined_images = torch.cat((real_images, gen_images))
            combined_labels = torch.cat((real_labels, gen_labels))
            
            #print(gen_images.shape, real_images.shape)
            
            # First update the D model
            d_32.zero_grad()
            d_outputs_combined = d_32(combined_images, alpha=alpha) 
            
            # Init the first D optim
            if optim_D is None:
                optim_D = torch.optim.Adam(d_32.parameters(), lr=0.00001, betas=(0, 0.99), eps=10**(-8))
            else:
                # Safely add new discriminator parameters
                existing_d_params = set(p for group in optim_D.param_groups for p in group['params'])
                new_d_params = set(d_new_params) - existing_d_params
                if new_d_params:
                    optim_D.add_param_group({'params': list(new_d_params)})
                    
            # Init the first G optim
            if optim_G is None:
                optim_G = torch.optim.Adam(g_32.parameters(), lr=0.00001, betas=(0, 0.99), eps=10**(-8))
            else:
                # Safely add new generator parameters
                existing_g_params = set(p for group in optim_G.param_groups for p in group['params'])
                new_g_params = set(g_new_params) - existing_g_params
                if new_g_params:
                    optim_G.add_param_group({'params': list(new_g_params)})

            loss_d = criterion(d_outputs_combined, combined_labels)
            #loss_d, _ = criterion(d_32, real_images, gen_images.detach(), alpha)
            loss_d.backward()
            optim_D.step()
            
            # Generate new images for updating G
            noise_tensor = torch.randn((batch_size, 512, 1, 1)).to(device)

            # Next update the G model, 
            #if i % 2 == 0:  # Update generator less frequently
            g_32.zero_grad()
            gen_images = g_32(noise_tensor, alpha=alpha)  
            d_outputs_generated = d_32(gen_images, alpha=alpha)
            loss_g = criterion(d_outputs_generated, real_labels)
            #_, loss_g = criterion(d_32, real_images, gen_images, alpha)
            loss_g.backward()
            optim_G.step()
            
        imshow(torchvision.utils.make_grid(gen_images.cpu()))
        print(f'Layer {layer+1}: Loss_D: {loss_d.item()}, Loss_G: {loss_g.item()}')
        
        alpha += 1/25
        alpha = round(alpha, 2)
        
    print(f'Alpha after grow: {alpha}')
    for epoch_train in range(25):
        for i, data in enumerate(dataloader):
            real_images, _ = data
            real_images = real_images.to(device)

            noise_tensor = torch.randn((batch_size, 512, 1, 1), device=device)

            with torch.no_grad():
                gen_images = g_32(noise_tensor, alpha=alpha)

            real_images = F.interpolate(real_images, size=gen_images.shape[2:])
            
            gen_labels = torch.zeros((batch_size, 1)).to(device)
            real_labels = torch.ones((batch_size, 1)).to(device)

            combined_images = torch.cat((real_images, gen_images))
            combined_labels = torch.cat((real_labels, gen_labels))

            # First update the D model
            d_32.zero_grad()
            d_outputs_combined = d_32(combined_images, alpha=alpha) 
            
            loss_d = criterion(d_outputs_combined, combined_labels)
            #loss_d, _ = criterion(d_32, real_images, gen_images.detach(), alpha)
            loss_d.backward()
            optim_D.step()
            
            # Generate new images for updating G
            noise_tensor = torch.randn((batch_size, 512, 1, 1)).to(device)

            # Next update the G model, 
            #if i % 2 == 0:  # Update generator less frequently
            g_32.zero_grad()
            gen_images = g_32(noise_tensor, alpha=alpha)  # Gen new images for training G
            d_outputs_generated = d_32(gen_images, alpha=alpha)
            loss_g = criterion(d_outputs_generated, real_labels)
            #_, loss_g = criterion(d_32, real_images, gen_images, alpha)
            loss_g.backward()
            optim_G.step()
    
    print(f'FINAL | Layer {layer+1}: Loss_D: {loss_d.item()}, Loss_G: {loss_g.item()}')
    imshow(torchvision.utils.make_grid(real_images.cpu()))
    imshow(torchvision.utils.make_grid(gen_images.cpu()))

I am unsure of how to proceed and fix this issue, I’ve been stumped for days now. I would like to ask if anybody can spot an issue in my code or is familiar with this sort of error in training progressive GANs?