# --- Discriminator (meta-learner) update from the query loss -----------------
self.meta_optim.zero_grad()
# NOTE(review): retain_graph=True is only required if code outside this view
# backpropagates through loss_q's graph a second time. The generator step
# below builds its own fresh graph, so this flag can likely be dropped
# (saving memory) — confirm against the rest of the training loop.
loss_q.backward(retain_graph=True)
# optimize the discriminator
self.meta_optim.step()

# --- Generator update: feature matching + pull-away term ----------------------
pt_weight = 0.8  # weight of the pull-away (repelling) term; same value as before
for i in range(tasks_per_batch):
    # Use *detached* copies of the task-adapted discriminator weights.
    # `p.detach()` alone is a no-op: it returns a new tensor and the original
    # loop discarded it. That let loss_gen backpropagate through fast_weights[i]
    # into the discriminator graph, whose leaf parameters meta_optim.step()
    # has just modified in place.
    disc_weights = [p.detach() for p in fast_weights[i]]

    qry_gen = self.generator(x_qry[i], vars=None, bn_training=True)

    # Real-image features only serve as a fixed target. They depend on
    # x_qry[i] and detached weights, so no generator gradient flows through
    # them and the graph is not needed.
    with torch.no_grad():
        dis_feat = self.discriminator(x_qry[i], disc_weights, bn_training=True, feat=True)
    gen_feat = self.discriminator(qry_gen, disc_weights, bn_training=True, feat=True)

    # Feature matching: L1 distance between mean generated and mean real features.
    fm_loss = torch.mean(torch.abs(gen_feat.mean(0) - dis_feat.mean(0)))

    # Pull-away term: mean squared off-diagonal cosine similarity between the
    # generated samples' features (gen_feat is treated as 2-D, as torch.mm
    # required in the original). Norms are clamped so an all-zero feature row
    # does not produce NaN.
    nsample = gen_feat.size(0)
    gen_feat_norm = gen_feat / gen_feat.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12)
    cosine = torch.mm(gen_feat_norm, gen_feat_norm.t())
    if nsample > 1:
        # Zero the diagonal (self-similarity). Build the mask on cosine's own
        # device/dtype instead of a hard-coded .cuda() with deprecated Variable.
        off_diag = 1.0 - torch.eye(nsample, device=cosine.device, dtype=cosine.dtype)
        pt_loss = pt_weight * torch.sum((cosine * off_diag) ** 2) / (nsample * (nsample - 1))
    else:
        # With a single sample there are no pairs; avoid dividing by zero.
        pt_loss = cosine.new_zeros(())
    loss_gen = fm_loss + pt_loss

    # Accumulate a plain float so no device tensor is stored in the log dict.
    corrects["gen_discrim"][0] += fm_loss.item()

    self.gen_optim.zero_grad()
    loss_gen.backward()
    # optimize the generator
    self.gen_optim.step()
In the image, the left side shows real images and the right side shows generated images. Is my backward procedure wrong, or could the problem lie elsewhere (e.g., in the model or the loss function)?