GAN loss remains the same during training

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter

class Discriminator(nn.Module):
    def __init__(self, img_channels=3, features=64):  # input: 3x64x64
        super().__init__()
        self.disc = nn.Sequential(
            self._make_conv2d_block(img_channels, features, use_bn=False),  # 64x32x32

            self._make_conv2d_block(features, features*2),    # 128x16x16
            self._make_conv2d_block(features*2, features*4),  # 256x8x8
            self._make_conv2d_block(features*4, features*8),  # 512x4x4

            self._make_conv2d_block(features*8, 1, 4, 2, 0, use_act=False),  # 1x1x1
            nn.Sigmoid(),
        )

    def _make_conv2d_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, use_bn=True, use_act=True, leak=0.2):
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
        ]
        if use_bn:
            layers.insert(1, nn.BatchNorm2d(out_channels))
        if use_act:
            layers.insert(2, nn.LeakyReLU(leak))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.disc(x)

class Generator(nn.Module):
    def __init__(self, z_dim=100, features=64, img_channels=3):  # input: 100x1x1
        super().__init__()
        self.gen = nn.Sequential(
            self._make_convT2d_block(z_dim, features*16, 4, 1, 0),  # 1024x4x4

            self._make_convT2d_block(features*16, features*8),  # 512x8x8
            self._make_convT2d_block(features*8, features*4),   # 256x16x16
            self._make_convT2d_block(features*4, features*2),   # 128x32x32

            self._make_convT2d_block(features*2, img_channels, use_bn=False, use_act=False),  # 3x64x64
            nn.Tanh(),
        )

    def _make_convT2d_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, use_bn=True, use_act=True):
        layers = [
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            )
        ]
        if use_bn:
            layers.insert(1, nn.BatchNorm2d(out_channels))
        if use_act:
            layers.insert(2, nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.gen(x)

def initialize_weights(model):
    # DCGAN-style init: all conv and batchnorm weights drawn from N(0, 0.02)
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(
                tensor=m.weight.data,
                mean=0.0,
                std=0.02,
            )

# quick sanity check on the output shapes of both networks
BATCH_SIZE, NUM_CHANNELS, H, W = 8, 3, 64, 64
Z_DIM = 100
FEATURES = 64
img_test = torch.randn((BATCH_SIZE, NUM_CHANNELS, H, W))
disc = Discriminator(NUM_CHANNELS, FEATURES)
assert disc(img_test).shape == (BATCH_SIZE, 1, 1, 1), 'Disc Test Failed'
noise_test = torch.randn((BATCH_SIZE, Z_DIM, 1, 1))
gen = Generator(Z_DIM, FEATURES, NUM_CHANNELS)
assert gen(noise_test).shape == (BATCH_SIZE, NUM_CHANNELS, H, W), 'Gen Test Failed'
print('Success!')  # prints: Success!

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
NUM_CHANNELS = 1
BATCH_SIZE = 64
IMG_SIZE = 64
Z_DIM = 100
FEATURES_DISC = 64
FEATURES_GEN = 64
LR = 2e-4
NUM_EPOCHS = 5
DISC_ITER = 1
GRID_SHOW = 8

disc = Discriminator(NUM_CHANNELS, FEATURES_DISC).to(device)
gen = Generator(Z_DIM, FEATURES_GEN, NUM_CHANNELS).to(device)
initialize_weights(disc), initialize_weights(gen)
fixed_noise = torch.randn((GRID_SHOW, Z_DIM, 1, 1)).to(device)

transformations = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(NUM_CHANNELS)], [0.5 for _ in range(NUM_CHANNELS)]
        ),
    ]
)
images = datasets.MNIST(root='mnist', train=True, transform=transformations, download=True)
images_loader = DataLoader(images, batch_size=BATCH_SIZE, shuffle=True)

optim_disc = optim.Adam(disc.parameters(), lr=LR, betas=(0.5, 0.999))
optim_gen = optim.Adam(gen.parameters(), lr=LR, betas=(0.5, 0.999))
criterion = nn.BCELoss()

writer_fake = SummaryWriter('logs/fake')
writer_real = SummaryWriter('logs/real')
steps = 1

disc.train(), gen.train()

for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, _) in enumerate(images_loader):
        real = real.to(device)

        # train discriminator: maximize log(D(real)) + log(1 - D(G(z)))
        for _ in range(DISC_ITER):
            noise = torch.randn((real.shape[0], Z_DIM, 1, 1)).to(device)
            fake = gen(noise)
            disc_real = disc(real).view(-1)
            loss_disc_real = criterion(
                input=disc_real,
                target=torch.ones_like(disc_real),
            )
            disc_fake = disc(fake.detach()).view(-1)
            loss_disc_fake = criterion(
                input=disc_fake,
                target=torch.zeros_like(disc_fake),
            )
            loss_disc = loss_disc_real + loss_disc_fake
            disc.zero_grad()
            loss_disc.backward()
            optim_disc.step()

        # train generator: maximize log(D(G(z)))
        disc_fake = disc(fake).view(-1)
        loss_gen = criterion(
            input=disc_fake,
            target=torch.ones_like(disc_fake),
        )
        gen.zero_grad()
        loss_gen.backward()
        optim_gen.step()

        if batch_idx % 50 == 0:
            print(
                f'Epoch: {epoch+1}/{NUM_EPOCHS} -- Step: {steps} -- Batch: {batch_idx+1}/{len(images_loader)} -- Disc Loss: {loss_disc:.4f} -- Gen Loss: {loss_gen:.4f}'
            )
            with torch.no_grad():
                fake = gen(fixed_noise)
                fake_images = torchvision.utils.make_grid(fake, nrow=GRID_SHOW, normalize=True)
                real_images = torchvision.utils.make_grid(real[:GRID_SHOW], nrow=GRID_SHOW, normalize=True)
                writer_fake.add_image(
                    'Fake', fake_images, global_step=steps
                )
                writer_real.add_image(
                    'Real', real_images, global_step=steps
                )
                steps += 1

Epoch: 1/5 -- Step: 1 -- Batch: 1/938 -- Disc Loss: 1.3863 -- Gen Loss: 0.6932
Epoch: 1/5 -- Step: 2 -- Batch: 51/938 -- Disc Loss: 1.3863 -- Gen Loss: 0.6932
Epoch: 1/5 -- Step: 3 -- Batch: 101/938 -- Disc Loss: 1.3863 -- Gen Loss: 0.6932
Epoch: 1/5 -- Step: 4 -- Batch: 151/938 -- Disc Loss: 1.3863 -- Gen Loss: 0.6931
Epoch: 1/5 -- Step: 5 -- Batch: 201/938 -- Disc Loss: 1.3863 -- Gen Loss: 0.6931
Epoch: 1/5 -- Step: 6 -- Batch: 251/938 -- Disc Loss: 1.3863 -- Gen Loss: 0.6932
Epoch: 1/5 -- Step: 7 -- Batch: 301/938 -- Disc Loss: 1.3863 -- Gen Loss: 0.6931

What the actual eff!!!
I had been on the verge of a breakdown because of this problem, suffering with it for more than an entire week.
So the moral of the story is that there was nothing wrong with the PyTorch routine; it was actually a GAN issue.
I just had to remove the BatchNorm2d from the last layer of the discriminator, and there goes the loss, no longer stuck flat like a '—'.
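For anyone hitting the same wall, here is a minimal sketch of the change, reusing the same _make_conv2d_block helper as above; the only edit is passing use_bn=False when building the final single-logit block (use_bn currently defaults to True there). It also explains the frozen numbers: 1.3863 ≈ 2·ln 2 and 0.6931 ≈ ln 2 are exactly the BCE values you get when the discriminator predicts ~0.5 for every image, which is what a BatchNorm2d sitting right before the Sigmoid tends to enforce by normalizing each batch of logits towards zero.

# Discriminator body with the fix applied: the last block passes use_bn=False,
# so there is no BatchNorm2d on the 1x1x1 logit layer feeding the Sigmoid.
self.disc = nn.Sequential(
    self._make_conv2d_block(img_channels, features, use_bn=False),  # 64x32x32
    self._make_conv2d_block(features, features*2),    # 128x16x16
    self._make_conv2d_block(features*2, features*4),  # 256x8x8
    self._make_conv2d_block(features*4, features*8),  # 512x4x4
    self._make_conv2d_block(features*8, 1, 4, 2, 0, use_bn=False, use_act=False),  # 1x1x1
    nn.Sigmoid(),
)

This also lines up with the DCGAN recipe, which already skips batchnorm on the generator's output layer and the discriminator's input layer; the final logit layer is another place it does more harm than good.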
GANs are extremely sensitive to architectures and hyper-parameters.