PyTorch not optimizing my model

Hi all,

I have trained a generator named netG (which is kept fixed in the code below). Given x, I then tried to use the Adam optimizer to minimize the following loss in order to recover z:

import torch
import torch.nn as nn

class loss(nn.Module):
    def __init__(self, nz):
        super().__init__()
        # Latent code to optimize; nn.Parameter requires grad by default.
        self.z = nn.Parameter(torch.zeros(1, 10, 1, 1))
        self.criterion = nn.BCELoss()  # currently unused in forward()

    def forward(self, x, netG, netD):
        # Reconstruction loss between the target x and the generator output G(z)
        l_rec = torch.norm(x - netG(self.z))
        print('L_rec:', l_rec.detach().numpy())
        return l_rec
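
For context, the model and optimizer are set up roughly like this (the learning rate here is just an example value):

model = loss(nz=10)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)  # optimizes only model.z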

The optimization process is:

import numpy as np

# Number of optimization steps
epoch = 1000
loss_epoch = np.zeros(epoch)
for i in range(epoch):
    # Reset the gradients accumulated by the optimizer
    opt.zero_grad()
    lss = model(x, netG, netD)
    loss_epoch[i] = lss.item()
    # Backpropagate to compute gradients w.r.t. model.z
    lss.backward()
    # Take an optimization step
    opt.step()
But the result shows that z does not update (the loss is not decreasing); the loss curve is shown in the figure below.


Can anyone help me?
Thanks!


Can you try printing self.z[0][0][0][0] in forward(), to see if it becomes non-zero?
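
Something along these lines inside forward(), plus an optional look at the gradient after backward(), should show whether anything reaches z (the grad check is just an extra suggestion):

# Inside loss.forward(), before returning:
print('z[0,0,0,0] =', self.z[0, 0, 0, 0].item())

# Optionally, in the training loop after lss.backward(),
# check whether any gradient reaches z at all:
print('z.grad norm =', model.z.grad.norm().item())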

Yes, I printed it, and it is zero all the time.

Is this expected behaviour? If this tensor stays at all zeroes, will your loss decrease? Or does your loss decreasing depend on this tensor changing?

No, I expect that z should not stay at zero all the time. The loss decreasing does depend on this tensor changing, which is why I expect the loss to decrease.

In that case you have to write code in your netG class to ensure that gradients on z are propagated (I am not sure if that is the correct term; what I mean is that you should tell PyTorch how to compute gradients with respect to z when you call netG.forward(z)).

This is not something that I have done (yet!), so I am not entirely clear on how to do this. I think you need to implement a backward function inside your netG class. The code in this message may help you get started.
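
As a rough illustration of what I mean, a custom backward can be written with torch.autograd.Function. This is only a generic toy sketch, not code specific to your netG; if netG is built from standard differentiable layers, autograd already computes these gradients for you:

import torch

class ScaledIdentity(torch.autograd.Function):
    # Toy example: y = 2 * z, with an explicitly written backward pass.

    @staticmethod
    def forward(ctx, z):
        return 2.0 * z

    @staticmethod
    def backward(ctx, grad_output):
        # d(loss)/dz = 2 * d(loss)/dy, by the chain rule
        return 2.0 * grad_output

z = torch.zeros(1, 10, 1, 1, requires_grad=True)
out = ScaledIdentity.apply(z)
out.sum().backward()
print(z.grad)  # filled with 2s, so gradients do reach z even when z is all zeros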

Thanks! I tried initializing z as torch.rand(1, 10, 1, 1) and it works now, but I don't know why.
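
For anyone reading later, the change that made it work was only the initialization of z; the rest of the class stays the same:

# In loss.__init__(), initialize z randomly instead of with zeros:
self.z = nn.Parameter(torch.rand(1, 10, 1, 1))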