Why can't my GAN implementation generate images that look like the real ones?

    # optimize the discriminator

    for i in range(tasks_per_batch):
        qry_gen = self.generator(x_qry[i], vars=None, bn_training=True)
        for p in fast_weights[i]:
            p.detach()
        dis_feat = self.discriminator(x_qry[i], fast_weights[i], bn_training=True, feat=True)
        gen_feat = self.discriminator(qry_gen, fast_weights[i], bn_training=True, feat=True)
        fm_loss = torch.mean(torch.abs(torch.mean(gen_feat, 0) - torch.mean(dis_feat, 0)))

        nsample = gen_feat.size(0)
        gen_feat_norm = gen_feat / gen_feat.norm(p=2, dim=1).reshape(-1, 1).expand_as(gen_feat)
        cosine = torch.mm(gen_feat_norm, gen_feat_norm.t())
        mask = Variable((torch.ones(cosine.size()) - torch.diag(torch.ones(nsample))).cuda())
        pt_loss = 0.8 * torch.sum((cosine * mask) ** 2) / (nsample * (nsample - 1))

        loss_gen = fm_loss + pt_loss
        with torch.no_grad():
            corrects["gen_discrim"][0] += fm_loss

    # optimize the generator


The left image shows real samples and the right one shows generated samples. Is my backward procedure wrong, or could the problem be somewhere else, such as the model or the loss function?
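The generator loss in the snippet (a feature-matching term plus a pull-away term weighted by 0.8) can be isolated into a small, device-agnostic function, which makes it easier to sanity-check on its own. This is a sketch under assumptions: generator_feature_loss and pt_weight are hypothetical names, both inputs are assumed to be 2-D (batch, feat_dim) discriminator features with batch size greater than 1, and F.normalize replaces the manual norm division (it adds a small epsilon to avoid division by zero).

```python
import torch
import torch.nn.functional as F


def generator_feature_loss(gen_feat: torch.Tensor, real_feat: torch.Tensor,
                           pt_weight: float = 0.8) -> torch.Tensor:
    """Feature-matching loss plus pull-away term (hypothetical helper).

    Assumes both inputs are 2-D (batch, feat_dim) discriminator features
    and that the batch contains more than one sample.
    """
    # Feature matching: mean absolute difference between batch-mean features.
    fm_loss = torch.mean(torch.abs(gen_feat.mean(0) - real_feat.mean(0)))

    # Pull-away term: mean squared pairwise cosine similarity between
    # generated samples, excluding each sample's similarity with itself.
    n = gen_feat.size(0)
    normed = F.normalize(gen_feat, p=2, dim=1)  # row-wise L2 normalization
    cosine = normed @ normed.t()
    # Off-diagonal mask built directly on the features' device and dtype,
    # so neither Variable nor an explicit .cuda() call is needed.
    off_diag = 1.0 - torch.eye(n, device=gen_feat.device, dtype=gen_feat.dtype)
    pt_loss = pt_weight * torch.sum((cosine * off_diag) ** 2) / (n * (n - 1))

    return fm_loss + pt_loss
```

As a quick check: with four identical generated feature vectors (and matching means) every off-diagonal cosine is 1, so the loss is 0.8, i.e. pt_weight × 1; with mutually orthogonal feature vectors it drops to 0.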

I’m not familiar with your code, but here are some minor issues you should check:

  • p.detach() won’t do anything, as it’s not in-place: it returns a new, detached tensor, and that result is discarded here (p.detach_() would be the in-place method).
  • Variable has been deprecated since PyTorch 0.4, so in newer versions you can use plain tensors instead.
  • Are you sure you need retain_graph=True? It is sometimes added as a workaround for a legitimate error, which silences the error message without fixing the underlying problem.
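The first two points can be checked in a few lines. This is a toy sketch: p is a random tensor standing in for one entry of fast_weights, and the mask mirrors the one the question builds with Variable.

```python
import torch

p = torch.randn(3, requires_grad=True)  # stands in for one tensor in fast_weights

# Out-of-place: detach() returns a NEW tensor and leaves p untouched,
# so a bare `p.detach()` statement whose result is discarded does nothing.
q = p.detach()
print(p.requires_grad, q.requires_grad)  # True False

# In-place: detach_() modifies the tensor it is called on.
# Shown on a non-leaf result here; it raises an error if called on a view.
y = p * 2
y.detach_()
print(y.requires_grad, y.grad_fn)  # False None

# Since PyTorch 0.4 no Variable wrapper is needed: plain tensors carry
# autograd state and can be created directly on the right device.
mask = torch.ones(3, 3, device=p.device) - torch.eye(3, device=p.device)
```

If the intent of that loop was to keep the generator loss from reaching the discriminator's fast weights (an assumption about the intent, not something stated in the thread), one option is to pass detached copies to the discriminator call, e.g. [w.detach() for w in fast_weights[i]], rather than modifying the list in place.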

Thanks very much for your reply. I will try to fix the problems you mentioned.
If I don't use retain_graph=True, the program raises an error.
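On the retain_graph=True error: by default PyTorch frees a graph's intermediate buffers after backward(), so backpropagating through the same graph a second time raises an error. A common GAN ordering avoids this by detaching the fake batch for the discriminator step and running the discriminator on the fake batch again for the generator step. The sketch below uses toy nn.Sequential models and a plain BCE loss instead of the feature-matching loss from the question; every size and hyperparameter is an assumption chosen only to make it runnable.

```python
import torch
import torch.nn as nn

# Toy models standing in for the real generator and discriminator (hypothetical sizes).
G = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
D = nn.Sequential(nn.Linear(4, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
bce = nn.BCEWithLogitsLoss()


def train_step(real):
    batch = real.size(0)
    ones = torch.ones(batch, 1)
    zeros = torch.zeros(batch, 1)
    fake = G(torch.randn(batch, 8))

    # 1) Discriminator step: fake.detach() makes this backward pass stop at D,
    #    so the generator's graph is neither traversed nor freed here.
    opt_d.zero_grad()
    loss_d = bce(D(real), ones) + bce(D(fake.detach()), zeros)
    loss_d.backward()
    opt_d.step()

    # 2) Generator step: a fresh forward pass through D builds a new graph,
    #    and G's graph is still intact, so no retain_graph=True is needed.
    opt_g.zero_grad()
    loss_g = bce(D(fake), ones)
    loss_g.backward()
    opt_g.step()
    return loss_d.item(), loss_g.item()
```

Usage: call train_step(real_batch) once per batch. The generator's backward pass also leaves gradients in D's parameters, but opt_d.zero_grad() clears them at the start of the next step, so they never leak into a discriminator update.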