Differentiate through SGD Step

I need the gradient to flow back through a previous update step, but the SGD step method is decorated with torch.no_grad() and updates the parameters in place.

I tried to do it manually, but if I use tensor.add_ to modify the model parameter I get RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

How can I do that?

Hi,

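If you do new_param = old_param - lr * grad (out of place, instead of an in-place add_), then it will be properly differentiable. If you also need to differentiate through the gradient itself, compute grad with torch.autograd.grad(..., create_graph=True).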
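A minimal sketch of such a differentiable step, assuming a toy quadratic loss and a single hand-made weight tensor (`w`, `lr` and `inner_loss` are illustrative names, not from the thread). The key points are the out-of-place update and `create_graph=True`, which keeps the graph of the gradient so the outer loss can be differentiated through it:

```python
import torch

# Hypothetical setup: one weight tensor and a toy quadratic loss.
w = torch.tensor([1.0, -2.0], requires_grad=True)
lr = 0.1

def inner_loss(p):
    return (p ** 2).sum()

# create_graph=True records the graph of the gradient computation itself.
(g,) = torch.autograd.grad(inner_loss(w), w, create_graph=True)

# Out-of-place SGD update: w_new is a non-leaf tensor with history back to w.
w_new = w - lr * g

# Outer loss evaluated at the updated weights; backward flows through the update.
outer = inner_loss(w_new)
outer.backward()
print(w.grad)  # tensor([ 1.2800, -2.5600])
```

Here w_new = 0.8 * w, so the outer gradient is 1.28 * w. Without `create_graph=True`, `g` would be treated as a constant and the result would be 1.6 * w instead.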

Note that if you try to work with torch.nn, though, an nn.Parameter can only be a leaf (and so cannot have history), so you won't be able to use them here.
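A small sketch of why nn.Parameter gets in the way, using a hypothetical nn.Linear: the in-place update raises the error from the question, a plain tensor cannot be assigned where a Parameter is registered, and re-wrapping the updated tensor in nn.Parameter produces a fresh leaf with no history:

```python
import torch
import torch.nn as nn

lin = nn.Linear(2, 1)

# In-place update of a leaf Parameter that requires grad: the error from the question.
try:
    lin.weight.add_(1.0)
    inplace_error = None
except RuntimeError as e:
    inplace_error = str(e)

# An out-of-place update gives a non-leaf tensor that keeps its history...
new_weight = lin.weight - 0.1 * torch.ones_like(lin.weight)

# ...but nn.Module refuses a plain tensor in a slot registered as a Parameter.
try:
    lin.weight = new_weight
    assign_error = None
except TypeError as e:
    assign_error = str(e)

# Wrapping it in nn.Parameter creates a new leaf, so the history is dropped.
wrapped = nn.Parameter(new_weight)
print(new_weight.is_leaf, wrapped.is_leaf, wrapped.grad_fn)  # False True None
```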

Hi, thanks for replying :smile:
OK, that's clear, but let's say I've defined a ConvNet (which is what I did) using some torch.nn classes (Conv2d, etc.). How can I do it then?

I figure I can compute all the updated parameters by iterating over model.parameters() and applying the update to each of them, but then how do I "put" them back inside a working model?

I feel like there should be an easy solution, because everyone doing meta-learning must face the same problem I'm facing right now (I may be wrong; I don't actually know much about meta-learning, but I know optimization-based algorithms often involve differentiating through an inner training loop).

Hi,

Unfortunately there isn't, as nn.Module is, by design, not functional.
I would advise using a library like https://github.com/facebookresearch/higher that handles all of that for you (and provides differentiable versions of the PyTorch optimizers as well).
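If you'd rather not add a dependency, a hand-rolled sketch of the same idea: write the forward pass with torch.nn.functional so it takes the weights as an argument, and keep the nn.Module only as a container for the initial Parameters. `ConvNet`, `functional_forward` and all shapes are hypothetical, and this is a manual stand-in, not higher's API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical small ConvNet; used only to hold the initial Parameters.
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, kernel_size=3)   # 8x8 input -> 6x6
        self.fc = nn.Linear(4 * 6 * 6, 10)

    def forward(self, x):
        return self.fc(F.relu(self.conv(x)).flatten(1))

def functional_forward(params, x):
    # Same computation as ConvNet.forward, but weights come from `params`.
    h = F.relu(F.conv2d(x, params["conv.weight"], params["conv.bias"]))
    return F.linear(h.flatten(1), params["fc.weight"], params["fc.bias"])

torch.manual_seed(0)
model = ConvNet()
x, y = torch.randn(8, 1, 8, 8), torch.randint(0, 10, (8,))     # inner-step data
xq, yq = torch.randn(8, 1, 8, 8), torch.randint(0, 10, (8,))   # outer-step data
lr = 0.1

params = dict(model.named_parameters())
# Inner step: out-of-place SGD update that keeps the graph.
inner_loss = F.cross_entropy(functional_forward(params, x), y)
grads = torch.autograd.grad(inner_loss, list(params.values()), create_graph=True)
fast = {name: p - lr * g for (name, p), g in zip(params.items(), grads)}

# Outer (meta) loss on the updated weights; gradients reach model's Parameters.
outer_loss = F.cross_entropy(functional_forward(fast, xq), yq)
outer_loss.backward()
```

The drawback is that the forward pass is duplicated by hand for every architecture, which is the bookkeeping a library like higher takes care of.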

Thank you, I’ll check that out. I guess that’s what I actually need.
Just some other questions out of curiosity:

  1. What does functional mean in this context?
  2. Why isn’t nn.Module functional?
  3. Will this problem be addressed in a future release?

  • What does functional mean in this context?

It means that the function has no side effects. In particular, it takes the parameters as inputs instead of automatically capturing them from the state of the nn.Module.

  • Why isn’t nn.Module functional?

Because when you call mod(inp), it uses many Parameters that are not among the function's inputs.
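To make the distinction concrete, a small sketch contrasting the two styles on an nn.Linear. `linear_fn` is a hypothetical helper; F.linear is the stateless function that nn.Linear's forward calls with its own weight and bias:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

lin = nn.Linear(3, 2)
x = torch.randn(5, 3)

# Not functional: the weights come from lin's internal state, not from the arguments.
out_module = lin(x)

# Functional: every tensor the result depends on is an explicit input.
def linear_fn(weight, bias, inp):
    return F.linear(inp, weight, bias)

out_fn = linear_fn(lin.weight, lin.bias, x)

# Since weights are arguments, a non-leaf (updated) weight can be passed directly.
updated = linear_fn(lin.weight - 0.1 * lin.weight, lin.bias, x)
```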

  • Will this problem be addressed in a future release?

Yes, this is something we have in mind: transforming any nn.Module into a function that takes the Parameters as input. But there are quite a few details to clarify, and it is not done yet.
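For later readers: PyTorch 2.0 added `torch.func.functional_call`, which runs an existing nn.Module with a dict of tensors substituted for its Parameters. A sketch of one differentiable inner step with it, assuming PyTorch 2.0 or newer (the model, data and `lr` are placeholders):

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
x, y = torch.randn(16, 3), torch.randn(16, 1)
lr = 0.1

params = dict(model.named_parameters())
# Run the unchanged module, but with weights taken from `params`.
loss = nn.functional.mse_loss(functional_call(model, params, (x,)), y)
grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)

# Out-of-place SGD update; these tensors keep their history.
fast = {name: p - lr * g for (name, p), g in zip(params.items(), grads)}

# Meta loss with the updated weights; backward reaches the original Parameters.
meta_loss = nn.functional.mse_loss(functional_call(model, fast, (x,)), y)
meta_loss.backward()
```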