I’m having trouble computing gradients. Here’s my setup:
I have a parametric function f that I want to optimize (e.g. an MLP), but the loss computation involves an auxiliary function g:
y = f(x)
loss = g(y)
And I’d like to compute d loss / df, i.e. the gradient of the loss with respect to the parameters of f.
I’m sure this should be easy using .detach() or autograd.grad(), but I’m not sure how to do it.
Thanks a lot!
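For reference, a minimal sketch of the autograd.grad route, assuming f is a small nn.Sequential MLP and g is a hypothetical auxiliary function (a sum of squares). torch.autograd.grad returns d loss / dp for every parameter p of f, without writing anything into the parameters' .grad fields.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# f: the parametric function to optimize (a small MLP, as in the question)
f = nn.Sequential(nn.Linear(3, 8), nn.Tanh(), nn.Linear(8, 3))

# g: hypothetical auxiliary function; any differentiable torch code would do
def g(y):
    return (y ** 2).sum()

x = torch.randn(5, 3)
y = f(x)
loss = g(y)

# Gradient of the loss w.r.t. each parameter of f, returned as a tuple
grads = torch.autograd.grad(loss, list(f.parameters()))
for (name, p), dp in zip(f.named_parameters(), grads):
    print(name, tuple(dp.shape))
```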
Hi, could you please elaborate on what your auxiliary function looks like?
Nonetheless, since loss is a function of y, which is itself a function of the parameters of f, you should be fine creating an optimizer with f’s parameters and using the loss to backpropagate the errors.
Does that make sense?
Also, could you please show some concrete code (rather than pseudocode)?
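A minimal sketch of that suggestion, assuming a toy MLP for f and a hypothetical g (mean squared distance to a constant target of 1.0). The optimizer only receives f.parameters(), and loss.backward() carries the gradient back through g into f.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
f = nn.Sequential(nn.Linear(3, 8), nn.Tanh(), nn.Linear(8, 3))

def g(y):
    # hypothetical auxiliary function: squared distance to a fixed target
    return ((y - 1.0) ** 2).mean()

opt = torch.optim.Adam(f.parameters(), lr=1e-2)  # only f's parameters are updated
x = torch.randn(32, 3)

losses = []
for step in range(200):
    opt.zero_grad()
    loss = g(f(x))   # loss depends on f's parameters through y = f(x)
    loss.backward()  # gradients flow back through g into f
    opt.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```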
Thanks for the reply!
I’m working with a frozen autoencoder. Basically, what I’m doing is:
- translate a latent vector
- decode the translated vector
- rotate the decoded signal by many angles
- encode back each rotated object
- each latent vector obtained contributes to the loss of the translation function
So things look like this:
z_new = model(z_current, Z_gt)  # function to optimize
x = decode(z_new)  # decoder should be frozen
X = torch.stack([rotate(x, a) for a in theta])
Z = encode(X)  # encoder should be frozen
loss = some_distance_function(Z, Z_gt)
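A runnable toy version of this pipeline might look like the following. Everything in it is a hypothetical stand-in: linear layers play the role of the pretrained encoder/decoder, the decoded signal is treated as a set of 2-D points so that rotate() can be a plain rotation matrix, model ignores Z_gt for brevity, and the distance is a mean squared error. Without any freezing, the gradient also reaches the autoencoder weights.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
LATENT, N_POINTS = 8, 16
theta = [2 * math.pi * k / 6 for k in range(6)]  # hypothetical set of angles

# Hypothetical stand-ins for the pretrained autoencoder and the translation model
decoder = nn.Linear(LATENT, N_POINTS * 2)  # latent -> N 2-D points (flattened)
encoder = nn.Linear(N_POINTS * 2, LATENT)  # flattened points -> latent
model = nn.Linear(LATENT, LATENT)          # translation to optimize (no Z_gt input here)

def decode(z):
    return decoder(z).view(N_POINTS, 2)

def encode(X):
    return encoder(X.flatten(start_dim=1))  # X: (num_angles, N_POINTS, 2)

def rotate(x, angle):
    # rotate a (N, 2) point set by `angle` radians; differentiable w.r.t. x
    c, s = math.cos(angle), math.sin(angle)
    R = torch.tensor([[c, -s], [s, c]], dtype=x.dtype)
    return x @ R.T

z_current = torch.randn(LATENT)
Z_gt = torch.randn(len(theta), LATENT)  # hypothetical per-angle target latents

z_new = model(z_current)
x = decode(z_new)
X = torch.stack([rotate(x, a) for a in theta])  # (num_angles, N_POINTS, 2)
Z = encode(X)                                   # (num_angles, LATENT)
loss = ((Z - Z_gt) ** 2).sum(dim=1).mean()      # hypothetical distance function
loss.backward()

print(model.weight.grad is not None)    # True: the gradient reaches the model
print(encoder.weight.grad is not None)  # True as well: the autoencoder is not frozen
```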
What you said definitely makes sense, but I’m not sure how to freeze the autoencoder: if I use with torch.no_grad(), then Z, and thus the loss, end up without gradients. So could I simply iterate over the AE parameters and set requires_grad = False?
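To illustrate the torch.no_grad() problem with hypothetical toy layers: operations inside the no_grad block are not recorded by autograd, so Z has no link back to model and backward() fails.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 4)    # trainable
decoder = nn.Linear(4, 6)  # hypothetical stand-ins for the frozen autoencoder
encoder = nn.Linear(6, 4)

z_new = model(torch.randn(4))
with torch.no_grad():
    # ops in here are not recorded, so Z has no grad_fn connecting it to `model`
    Z = encoder(decoder(z_new))
loss = (Z ** 2).sum()

print(loss.requires_grad)  # False

backward_failed = False
try:
    loss.backward()
except RuntimeError as err:
    backward_failed = True
    print("backward failed:", err)
```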
Looping over a layer’s parameters and setting their requires_grad=False is one way, like you mentioned, and it has worked fine for me in the past.
I think you should be fine using it.
Any specific errors you anticipate or encounter with this?
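A minimal sketch of the freezing approach, again with hypothetical toy layers standing in for the autoencoder. Setting requires_grad to False on the autoencoder's parameters means no gradient is computed for those weights, but the operations are still recorded, so the gradient still flows through them to model. Calling module.requires_grad_(False) on the whole module is an equivalent one-liner.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 4)    # trainable translation function
decoder = nn.Linear(4, 6)  # hypothetical stand-ins for the pretrained autoencoder
encoder = nn.Linear(6, 4)

# Freeze the autoencoder: its weights become constants for autograd,
# while gradients still propagate through its operations to earlier tensors.
for module in (encoder, decoder):
    for p in module.parameters():
        p.requires_grad_(False)

opt = torch.optim.SGD(model.parameters(), lr=0.1)  # only the trainable params

z_new = model(torch.randn(4))
Z = encoder(decoder(z_new))
loss = (Z ** 2).sum()
opt.zero_grad()
loss.backward()
opt.step()

print(loss.requires_grad)                        # True
print(model.weight.grad is not None)             # True
print(encoder.weight.grad, decoder.weight.grad)  # None None
```

Note that requires_grad=False does not change layer behaviour: if the real autoencoder contains dropout or batch normalization, it presumably also needs .eval() to run in inference mode.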
Yeah, I tried it and it works, thank you!
I initially encountered the RuntimeError: Trying to backward through the graph a second time, because I kept the variable while looping over the optimization steps. But since I can discard the gradients between steps, I got away with:
z = z.detach().requires_grad_()
Thanks again!
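A sketch of where that line sits in an optimization loop, assuming a hypothetical nn.Linear translation model and a squared-norm loss: after each step the new latent is detached, so the next step builds a fresh graph that starts from a leaf tensor.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 4)  # hypothetical translation function being trained
opt = torch.optim.SGD(model.parameters(), lr=0.01)

z = torch.randn(1, 4, requires_grad=True)
for step in range(5):
    z_new = model(z)
    loss = (z_new ** 2).sum()
    opt.zero_grad()
    loss.backward()
    opt.step()
    # Carry the new latent into the next step as a fresh leaf: detach() cuts it
    # off from this step's graph, whose saved tensors backward() just freed.
    z = z_new.detach().requires_grad_()

print(z.is_leaf, z.grad_fn)  # True None
```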
Good to know it’s solved.
Regarding the error: was it something like z being part of a graph you had already backpropagated through, and you then trying to backpropagate through another graph that also contained z (which would essentially mean there was no separate graph, since the new one got built as an extension of the already existing graph)?
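The situation can be reproduced with a toy loop (hypothetical model and loss): if z_new is carried over without detaching, the second loss.backward() has to traverse the first step's graph, whose saved tensors were already freed by the first backward(), and PyTorch raises the RuntimeError.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 4)  # hypothetical translation function
opt = torch.optim.SGD(model.parameters(), lr=0.01)

z = torch.randn(1, 4, requires_grad=True)
error = None
for step in range(2):
    z_new = model(z)
    loss = (z_new ** 2).sum()
    opt.zero_grad()
    try:
        loss.backward()
    except RuntimeError as err:
        error = err
        break
    opt.step()
    z = z_new  # BUG: z_new still belongs to this step's (now freed) graph

print(f"failed at step {step}: {error}")
```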
Yes, that’s right: I’m updating z at each optimisation step, so that’s what happened. But as I said, I can discard the graph from one step to the next.