Discriminator Loss goes to 0 quickly during Adversarial Training for Semantic Segmentation

I am incorporating Adversarial Training for Semantic Segmentation from Adversarial Learning for Semi-Supervised Semantic Segmentation.

The idea is like this:
The discriminator takes as input a probability map (21x321x321) over 21 classes (PASCAL VOC dataset) and produces a confidence map of size 2x321x321. (a real/fake decision for each pixel). For an input image (3x321x321), the segmentation network (generator) produces ‘fake’ probability map (21x321x321). The ‘real’ probability map comes from the ground truth segmentation labels (21x321x321 using one-hot encoding). The Loss for both Generator and Discriminator are easier to understand through the code.

With my current implementation, the discriminator loss quickly goes to 0, which is a failure mode for GAN training(mentioned here). I am new to pytorch (as well as to adversarial training). So, I am not sure if there’s an issue with my network architecture, hyperparameters or simply my training scheme. I would really appreciate some pointers.

This is how my training looks:

[1][0] LD: 1.402573823928833 LG: 3.0447137355804443
[1][1] LD: 0.8725658655166626 LG: 2.544170618057251
[1][2] LD: 0.6969347596168518 LG: 2.1177046298980713
[1][3] LD: 0.611475944519043 LG: 1.7778557538986206
[1][4] LD: 0.49319764971733093 LG: 2.1366050243377686
[1][5] LD: 0.30195319652557373 LG: 1.7873120307922363
[1][6] LD: 0.14412544667720795 LG: 1.045764684677124
[1][7] LD: 0.04816107824444771 LG: 1.5864180326461792
[1][8] LD: 0.012304163537919521 LG: 1.370680332183838
[1][9] LD: 0.0035684951581060886 LG: 1.3428194522857666
[1][10] LD: 0.0011156484251841903 LG: 1.145486831665039
[1][11] LD: 0.00045744137605652213 LG: 1.371126651763916
[1][12] LD: 0.0001588731538504362 LG: 1.378540277481079
[1][13] LD: 4.844377326662652e-05 LG: 1.504058837890625
[1][14] LD: 3.028669743798673e-05 LG: 1.5484118461608887
[1][15] LD: 1.5183023606368806e-05 LG: 1.584553837776184
[1][16] LD: 2.0302868506405503e-05 LG: 1.4818311929702759
[1][17] LD: 1.0679158549464773e-05 LG: 1.2976796627044678
[1][18] LD: 1.5313835319830105e-06 LG: 1.2631664276123047
[1][19] LD: 4.273606009519426e-06 LG: 1.770961046218872
[1][20] LD: 1.1575384633033536e-06 LG: 1.5112217664718628
[1][21] LD: 2.138318961897312e-07 LG: 1.2034248113632202
[1][22] LD: 1.0100056897499599e-06 LG: 1.581740140914917
[1][23] LD: 5.6876764631397236e-08 LG: 1.0763123035430908
[1][24] LD: 1.475878548262699e-07 LG: 1.6125952005386353
[1][25] LD: 6.919402721905499e-07 LG: 1.6719598770141602
[1][26] LD: 1.3498377526843797e-08 LG: 1.1914349794387817
[1][27] LD: 1.3576584301233652e-08 LG: 1.1994632482528687
[1][28] LD: 3.087819067104647e-08 LG: 1.2909866571426392
[1][29] LD: 3.416153049329296e-07 LG: 2.143049478530884
[1][30] LD: 4.477038118011478e-08 LG: 1.7709745168685913
[1][31] LD: 2.1782324832742006e-09 LG: 1.2023413181304932
[1][32] LD: 1.0589346999267946e-07 LG: 1.4242452383041382

My training code is this:

generator = deeplabv2.Res_Deeplab()
optimizer_G = optim.SGD(filter(lambda p: p.requires_grad, \
        generator.parameters()),lr=0.00025,momentum=0.9,\
        weight_decay=0.0001,nesterov=True)
 discriminator = Dis(in_channels=21)
optimizer_D = optim.Adam(filter(lambda p: p.requires_grad, \
            discriminator.parameters()),lr=0.0001,weight_decay=0.0001)

for epoch in range(args.start_epoch,args.max_epoch+1):
  
     for batch_id, (img,mask,ohmask) in enumerate(trainloader):
          img,mask,ohmask = Variable(img.cuda()),Variable(mask.cuda(),requires_grad=False),\
                                Variable(ohmask.cuda(),requires_grad=False)
        # ohmask : mask (HxW) converted to one-hot encoded probability map for each class( 21xHxW )
         out_img_map = generator(img)
         out_img_map = nn.LogSoftmax()(out_img_map)
        #######################
        # Adverarial Training#
        #######################
        if args.mode == 'adv':

                N = out_img_map.size()[0]
                H = out_img_map.size()[2]
                W = out_img_map.size()[3]

                # Generate the Real and Fake Labels
                target_fake = Variable(torch.zeros((N,H,W)).long().cuda(),requires_grad=False)
                target_real = Variable(torch.ones((N,H,W)).long().cuda(),requires_grad=False)
                              
                #########################
                # Discriminator Training#
                #########################

                # Train on Real
                conf_map_real = nn.LogSoftmax()(discriminator(ohmask.float()))

                optimizer_D.zero_grad()

                LD_real = nn.NLLLoss2d()(conf_map_real,target_real)
                LD_real.backward()

                # Train on Fake
                conf_map_fake = nn.LogSoftmax()(discriminator(Variable(out_img_map.data)))
                LD_fake = nn.NLLLoss2d()(conf_map_fake,target_fake)
                LD_fake.backward()

                # Update Discriminator weights

                optimizer_D.step()

                ######################
                # Generator Training #
                #####################
                conf_map_fake = nn.LogSoftmax()(discriminator(out_img_map))
                LG_ce = nn.NLLLoss2d()(out_img_map,mask)
                LG_adv = args.lam_adv * nn.NLLLoss2d()(conf_map_fake,target_real)

                LG_seg = LG_ce + args.lam_adv * LG_adv
                optimizer_G.zero_grad()
                LG_ce.backward(retain_variables=True)
                LG_adv.backward()
                optimizer_G.step()
                print("[{}][{}] LD: {} LG: {}".format(epoch,i,(LD_real + LD_fake).data[0],LG_seg.data[0]))

I am using Resnet-101 as segmentation network (which is my generator) and my discriminator is as follows:

   class Dis(nn.Module):
    """
        Discriminator Network for the Adversarial Training.
    """
    def __init__(self,in_channels,negative_slope = 0.2):
        super(Dis, self).__init__()
        self._in_channels = in_channels
        self._negative_slope = negative_slope

        self.conv1 = nn.Conv2d(in_channels=self._in_channels,out_channels=64,kernel_size=4,stride=2,padding=2)
        self.relu1 = nn.LeakyReLU(self._negative_slope,inplace=True)
        self.conv2 = nn.Conv2d(in_channels=64,out_channels=128,kernel_size=4,stride=2,padding=2)
        self.relu2 = nn.LeakyReLU(self._negative_slope,inplace=True)
        self.conv3 = nn.Conv2d(in_channels=128,out_channels=256,kernel_size=4,stride=2,padding=2)
        self.relu3 = nn.LeakyReLU(self._negative_slope,inplace=True)
        self.conv4 = nn.Conv2d(in_channels=256,out_channels=512,kernel_size=4,stride=2,padding=2)
        self.relu4 = nn.LeakyReLU(self._negative_slope,inplace=True)
        self.conv5 = nn.Conv2d(in_channels=512,out_channels=2,kernel_size=4,stride=2,padding=2)

    def forward(self,x):
        x= self.conv1(x) # -,-,161,161
        x = self.relu1(x)
        x= self.conv2(x) # -,-,81,81
        x = self.relu2(x)
        x= self.conv3(x) # -,-,41,41
        x = self.relu3(x)
        x= self.conv4(x) # -,-,21,21
        x = self.relu4(x)
        x = self.conv5(x) # -,-,11,11
        # upsample
        x = F.upsample_bilinear(x,scale_factor=2)
        x = x[:,:,:-1,:-1] # -,-, 21,21

        x = F.upsample_bilinear(x,scale_factor=2)
        x = x[:,:,:-1,:-1] # -,-,41,41

        x = F.upsample_bilinear(x,scale_factor=2)
        x = x[:,:,:-1,:-1] #-,-,81,81

        x = F.upsample_bilinear(x,scale_factor=2)
        x = x[:,:,:-1,:-1] #-,-,161,161

        x = F.upsample_bilinear(x,scale_factor=2)
        x = x[:,:,:-1,:-1] # -,-,321,321

        return x

Thanks in advance!