Hi, I wanted to confirm whether…:
Will PyTorch backpropagate through parameters stored in regular Python lists? Or only through, e.g., nn.ParameterList?
Is backpropagation done ONLY for parameters in (i) model.named_parameters(), (ii) model.state_dict().keys()?
How can I ensure parameters will be included in backpropagation when calling loss.backward()?
Thanks for the reply!
Initially, I had various classifier heads / linear layers stored in a Python list that I would dynamically pick from when calling model.forward(). So, from your reply, I gather that backpropagation should have worked correctly during training. However, when saving the model for resuming trials or for inference, the parameters of the linear layers in the list were not stored in model.state_dict() and thus could not be restored correctly?
Yes, exactly. Your question is a bit too general to answer all valid use cases.
To recap: Autograd doesn’t care where parameters are coming from and is happy to just use leaf variables as you pass them to operations.
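A minimal sketch of that point: autograd tracks leaf tensors regardless of whether they live in a plain Python list (the list and shapes here are just illustrative):

```python
import torch

# Two leaf tensors kept in a plain Python list -- no nn.Module involved.
params = [torch.randn(3, requires_grad=True) for _ in range(2)]

# Use them in an operation and backpropagate.
loss = (params[0] * params[1]).sum()
loss.backward()

# Both leaves received gradients even though they live in a plain list.
print(all(p.grad is not None for p in params))  # True
```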
However, if you are writing a proper nn.Module and want to register the parameters inside it, do not use plain Python lists, as this will not register these parameters with the parent module. Use nn.ParameterList instead, which will make sure these parameters show up in the state_dict, will be moved to the specified device, etc.
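To illustrate the difference, here is a small toy module (the class and attribute names are made up for the example): the parameter stored in a plain list is invisible to state_dict() and parameters(), while the one in nn.ParameterList is registered:

```python
import torch
import torch.nn as nn

class Heads(nn.Module):
    def __init__(self):
        super().__init__()
        # NOT registered: invisible to state_dict() / parameters()
        self.plain = [nn.Parameter(torch.randn(2, 2))]
        # Registered: shows up in state_dict(), moves with .to(device), etc.
        self.registered = nn.ParameterList([nn.Parameter(torch.randn(2, 2))])

m = Heads()
print(list(m.state_dict().keys()))  # ['registered.0'] -- the plain list is missing
print(len(list(m.parameters())))    # 1
```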
Got it, thank you so much for the clarification!
On a side note, would it be necessary to toggle requires_grad on and off if I were using a list/nn.ModuleList and indexing into it to pick the classifier head per model.forward() call? Or would selecting the particular head for the forward pass already isolate it in the computation graph?
If you are not using specific parameters, they will not be included in the computation graph, so changing their requires_grad attribute might not be necessary.
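A quick sketch of that behavior, assuming heads stored in an nn.ModuleList and indexed per forward pass (the sizes are arbitrary): only the selected head enters the graph and receives gradients.

```python
import torch
import torch.nn as nn

# Three candidate classifier heads; one is picked per forward call.
heads = nn.ModuleList([nn.Linear(4, 2) for _ in range(3)])
x = torch.randn(5, 4)

out = heads[1](x)        # only head 1 enters the computation graph
out.sum().backward()

print(heads[0].weight.grad is None)  # True: unused head got no gradient
print(heads[1].weight.grad is None)  # False: selected head was updated
```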
However, if you are using an optimizer with internal running states (e.g. Adam), then parameters with a history (e.g. parameters which were previously updated) could still be updated by optimizer.step() even if they did not receive any new gradients.
The recommendation is thus again use case dependent.
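The Adam caveat can be demonstrated with a toy setup (the two heads and learning rate are made up for the example). Note this relies on zero_grad(set_to_none=False), which keeps zero-filled gradient tensors so the optimizer does not skip the parameter:

```python
import torch

# Two hypothetical "heads"; only head_a is used in the second step.
head_a = torch.nn.Linear(4, 2)
head_b = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(
    list(head_a.parameters()) + list(head_b.parameters()), lr=0.1
)
x = torch.randn(8, 4)

# Step 1: use head_b once so it accumulates optimizer state (running averages).
head_b(x).sum().backward()
opt.step()
opt.zero_grad(set_to_none=False)  # grads become zero tensors, not None

before = head_b.weight.clone()

# Step 2: only head_a participates; head_b's gradients stay at zero.
head_a(x).sum().backward()
opt.step()

# head_b was still updated, because Adam's running averages are nonzero.
print(torch.allclose(before, head_b.weight))  # False
```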