Suppose A is a module with multiple sub-modules, of which only a subset is used depending on the input. For example, A is defined as follows:
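```python
import torch
from torch import nn

class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.common = nn.Linear(100, 50)
        self.module1 = nn.Linear(50, 30)  # used when idx == 0
        self.module2 = nn.Linear(50, 2)   # used otherwise
    def forward(self, input, idx):
        commonOut = self.common(input)
        return self.module1(commonOut) if idx == 0 else self.module2(commonOut)

model = A()
opt = torch.optim.Adam(model.parameters())
```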
In this case, if A is run with idx=0, the parameters belonging to module2 do not need to be updated.
When using opt to update the parameters, does PyTorch automatically update only the modules that are actually used in the forward pass, or are the unused modules also updated with a zero gradient? The latter would be a waste of resources.
If it does update only the used modules, is there any way to get the list of parameters that are actually updated?
No, we don’t do that. There’s no way for the optimizer to know which parameters were or weren’t used, and there’s no possible way for it to find out. Supporting this would require imposing strict requirements that don’t make sense in most use cases, for only minor improvements in others. Why is it a problem that these parameters are getting updated? Is your script slow? It shouldn’t add much overhead.
Thanks for the clarification. Since my model uses only a small subset of the sub-modules depending on the input, it would be more efficient to update only the parameters that are actually used.
However, since the unused modules only incur parameter updates (no forward or backward passes), this is probably not the bottleneck of the whole process, as you mentioned. (If it is a bottleneck, I’ll use separate modules with their own parameter updates.)
If it turns out to be a bottleneck, you could create one optimizer per module and call .step() only on the optimizers of the modules that were used. That should be quite simple to implement.
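A minimal sketch of the per-module approach, assuming the A model from the question (idx=0 selects module1, any other value selects module2), a dummy squared-output loss, and Adam with default settings for every optimizer. The names opt_common, opt_branch and train_step are hypothetical. The shared common layer is used on every pass, so its optimizer always steps, while only the branch selected by idx is stepped.

```python
import torch
from torch import nn

class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.common = nn.Linear(100, 50)
        self.module1 = nn.Linear(50, 30)
        self.module2 = nn.Linear(50, 2)

    def forward(self, input, idx):
        commonOut = self.common(input)
        return self.module1(commonOut) if idx == 0 else self.module2(commonOut)

model = A()

# One optimizer per sub-module (hypothetical layout); the shared layer gets its own too.
opt_common = torch.optim.Adam(model.common.parameters())
opt_branch = [
    torch.optim.Adam(model.module1.parameters()),  # stepped only when idx == 0
    torch.optim.Adam(model.module2.parameters()),  # stepped only when idx == 1
]

def train_step(x, idx):
    """One update that steps only the optimizers of the modules used for idx (assumed 0 or 1)."""
    # Reset every optimizer's gradients so nothing stale carries into this step.
    opt_common.zero_grad()
    for opt in opt_branch:
        opt.zero_grad()
    out = model(x, idx)
    loss = out.pow(2).mean()  # dummy loss, for illustration only
    loss.backward()
    # The shared layer is used on every pass; only one branch is.
    opt_common.step()
    opt_branch[idx].step()
    return loss.item()

# Usage: a step with idx=0 leaves module2's weights untouched.
x = torch.randn(4, 100)
before = model.module2.weight.detach().clone()
train_step(x, 0)
print(torch.equal(before, model.module2.weight))
```

Keeping one optimizer per branch also means each branch’s Adam state (its step count and running averages) only advances when that branch is actually trained.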
Aside from performance, if the error is evaluated only in terms of the output of a single module, aren’t there other (possibly negative) implications of updating the weights of components that were not used in the forward pass?
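One such implication can be sketched under simplifying assumptions: with a momentum-based optimizer such as Adam, a parameter that receives an all-zero gradient is still moved by the running averages accumulated in earlier steps, whereas torch.optim.Adam skips a parameter whose .grad is None. The single tensor w, the learning rate, and the helper adam_drift are hypothetical stand-ins for module2’s weights and a training loop. Which case applies in practice depends on how gradients are reset (the set_to_none argument of zero_grad, which defaults to True in recent PyTorch releases).

```python
import torch

def adam_drift(set_to_none):
    """Return a parameter's value after a "used" step and after an "unused" step."""
    # Hypothetical stand-in for module2's weights.
    w = torch.nn.Parameter(torch.ones(3))
    opt = torch.optim.Adam([w], lr=0.1)

    # Step 1: the parameter is used, so it receives a non-zero gradient.
    w.grad = torch.ones(3)
    opt.step()
    after_used = w.detach().clone()

    # Step 2: the parameter is unused. set_to_none=False leaves an all-zero
    # gradient tensor (Adam still updates w); set_to_none=True leaves w.grad as None.
    opt.zero_grad(set_to_none=set_to_none)
    opt.step()
    return after_used, w.detach().clone()

used, final = adam_drift(set_to_none=False)
print("zero gradient, weights moved:", not torch.equal(used, final))
used, final = adam_drift(set_to_none=True)
print("None gradient, weights moved:", not torch.equal(used, final))
```

With an all-zero gradient the weights still change on the second step, because Adam’s first-moment estimate from step 1 has not decayed to zero; with a None gradient they stay exactly where step 1 left them.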