Optimizer parameters not updating

My code performs a loss minimization from the output of a pretrained GAN. Apart from the GAN parameters I am also trying to optimize the latent variable z.

The code is set up as follows:

with torch.no_grad():
   gan_out = G(z_initial)

z_initial.retain_grad = True
loss = MSEloss(gan_out, observed_image)

GAN_optim.zero_grad()
z_optim.zero_grad()
loss.backward()
GAN_optim.step()
z_optim.step()

The optimizers are defined as:

GAN_optim = optim.RMSprop(G.parameters(), lr = 3e-5)
z_optim = optim.RMSprop([z_initial], lr = 1e-3)

where z_initial is:

z_initial = torch.autograd.Variable(z_.clone().to(device), requires_grad = True)

After training, I visualize z_initial and the GAN weights. The GAN weights are updating and changing, but z_initial is not.

The TensorBoard visualization is done every epoch as follows:

writer.add_histogram('Z_prior', z_initial)

How can I make sure z_initial updates?

How about using nn.Parameter instead of Variable?

# move the data first, then wrap it, so z_initial stays a leaf nn.Parameter
z_initial = nn.Parameter(z_.to(device))

I tried that and had the same issue! I'm not sure what else I should do.

Your code shouldn’t work at all, since you are wrapping the forward pass in a no_grad guard, so no graph is recorded for backward() to traverse:

import torch
import torch.nn as nn

G = nn.Linear(1, 1)
z_initial = torch.randn(1, 1)

with torch.no_grad():
   gan_out = G(z_initial)

z_initial.retain_grad = True
loss = nn.MSELoss()(gan_out, torch.rand_like(gan_out))
loss.backward()
# > RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Could you explain the use case a bit and why you are disabling Autograd while trying to calculate the gradients?
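To see why backward() fails here, it can help to inspect the output tensor directly. The following is only a minimal sketch, assuming a toy nn.Linear stands in for the pretrained generator G (the names out_no_grad and out_tracked are made up for illustration):

```python
import torch
import torch.nn as nn

G = nn.Linear(1, 1)  # toy stand-in for the pretrained generator (assumption)
z_initial = torch.randn(1, 1, requires_grad=True)

# Inside no_grad, autograd records nothing, even though z_initial requires grad.
with torch.no_grad():
    out_no_grad = G(z_initial)

# Outside no_grad, the same call is tracked.
out_tracked = G(z_initial)

print(out_no_grad.requires_grad, out_no_grad.grad_fn)
# > False None
print(out_tracked.requires_grad, out_tracked.grad_fn is not None)
# > True True
```

Since the no_grad output has no grad_fn, there is no path from the loss back to z_initial, so z_initial can never receive a gradient this way.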

Yeah, I think you are right; I suspected that might be a problem. I need the generator to be frozen when computing gan_out, which is why I was disabling autograd.

Will using G.no_grad() work instead?

No, no_grad() is not a method of nn.Module.
Are you trying to update the input to G only while G is frozen?
If so, this should work:

import torch
import torch.nn as nn

G = nn.Linear(1, 1)
# freeze
for param in G.parameters():
    param.requires_grad = False

z_initial = torch.randn(1, 1, requires_grad=True)

gan_out = G(z_initial)

loss = nn.MSELoss()(gan_out, torch.rand_like(gan_out))
loss.backward()
print(z_initial.grad)
# > tensor([[0.5429]])  (the exact value varies with the random init)
print(G.weight.grad)
# > None
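As a possible shorthand for the freezing loop, nn.Module also provides requires_grad_(), which sets the flag on all of a module's parameters at once. A minimal sketch, again assuming a toy nn.Linear in place of the real generator and a zero target chosen only for illustration:

```python
import torch
import torch.nn as nn

G = nn.Linear(1, 1)      # toy stand-in for the pretrained generator (assumption)
G.requires_grad_(False)  # sets requires_grad=False on every parameter of G

z_initial = torch.randn(1, 1, requires_grad=True)

# Forward pass WITHOUT a no_grad guard, so the path from z_initial is tracked.
loss = nn.MSELoss()(G(z_initial), torch.zeros(1, 1))
loss.backward()

print(all(not p.requires_grad for p in G.parameters()))
# > True
print(z_initial.grad is not None)
# > True
print(G.weight.grad)
# > None
```

With G frozen this way, there is no need for a separate optimizer over G.parameters(); only the optimizer that holds z_initial has anything to update.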

Yeah, my use case is specifically keeping G frozen and generating samples from it. It is a test-time application.

I will try this method, thanks!

Another thing -

After this loss.backward() call, I need to modify z_initial to clamp it like so:

z_initial = torch.clamp(z_initial, -1., 1.)

I am not able to do this: I get an error saying the code is trying to traverse the graph a second time after some of its buffers have been freed. The error itself makes sense to me.

However, once I call the optimizer for z, I need to clamp it. Is there any other way this can be done?

You might want to clamp the tensor in place, which avoids having to pass it to the optimizer again after every update, as seen here:

import torch
import torch.nn as nn

G = nn.Linear(1, 1)
# freeze
for param in G.parameters():
    param.requires_grad = False

z_initial = torch.randn(1, 1, requires_grad=True)
optimizer = torch.optim.Adam([z_initial], lr=1e-3)

for epoch in range(10):
    optimizer.zero_grad()
    gan_out = G(z_initial)

    loss = nn.MSELoss()(gan_out, torch.rand_like(gan_out))
    loss.backward()
    print('epoch {}, loss {}, z_initial {}, .grad {}'.format(
        epoch, loss.item(), z_initial, z_initial.grad))
    optimizer.step()
    with torch.no_grad():
        torch.clamp_(z_initial, -1., 1.)
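
The difference between the two clamping styles can be checked directly. This is a small sketch rather than part of the training loop: the starting value 3.0 and the name z_rebound are assumptions chosen only to make the effect visible.

```python
import torch

# Start deliberately outside [-1, 1] so the clamp has a visible effect.
z_initial = torch.full((1, 1), 3.0, requires_grad=True)
optimizer = torch.optim.Adam([z_initial], lr=1e-3)
tracked = optimizer.param_groups[0]['params'][0]  # tensor the optimizer updates

# Out-of-place clamp: returns a NEW non-leaf tensor attached to the graph.
# The optimizer still holds the original tensor, not this one.
z_rebound = torch.clamp(z_initial, -1., 1.)
print(z_rebound is tracked, z_rebound.is_leaf)
# > False False

# In-place clamp under no_grad: same tensor object, still a leaf.
with torch.no_grad():
    z_initial.clamp_(-1., 1.)
print(z_initial is tracked, z_initial.is_leaf, z_initial.item())
# > True True 1.0
```

Rebinding the name z_initial to the out-of-place result means later iterations run on a tensor the optimizer never sees, and the clamp operation itself stays in the autograd graph, which is likely the source of the "graph traversed twice" error.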