Your approach would generally work, but you would have to be careful when creating the optimizer, since the parameters are not fully initialized while the model is being instantiated.
E.g. this standard approach would be wrong:

```python
model = Model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
```

since `model.parameters()` does not yet contain the parameters of `self.linear`.
You would thus have to use:

```python
model = Model()
out = model(dummy_input)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
```
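To make the pitfall concrete, here is a minimal sketch, assuming a hypothetical `Model` that creates `self.linear` inside its first forward pass (your actual module may differ):

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    # Hypothetical model that creates its submodule lazily
    # during the first forward pass.
    def __init__(self):
        super().__init__()
        self.linear = None  # created once the input shape is known

    def forward(self, x):
        if self.linear is None:
            # infer in_features from the first batch
            self.linear = nn.Linear(x.size(1), 2)
        return self.linear(x)

model = Model()
print(len(list(model.parameters())))  # 0: self.linear does not exist yet

out = model(torch.randn(4, 8))        # first forward creates self.linear
print(len(list(model.parameters())))  # 2: weight and bias are now registered

# only now will the optimizer see the parameters
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
```

Creating the optimizer before the dummy forward pass would silently train nothing, since its parameter list would be empty.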
The same limitation applies to the usage of `Lazy*` modules as described here.
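For reference, the built-in lazy modules show the same behavior; a short sketch using `nn.LazyLinear`:

```python
import torch
import torch.nn as nn

lazy = nn.LazyLinear(out_features=2)  # in_features is inferred later
# Before the first forward pass the weight is an UninitializedParameter
# without a materialized shape, so an optimizer created now would be
# working with uninitialized parameters.
print(type(lazy.weight))

out = lazy(torch.randn(4, 8))  # materializes weight as (2, 8)
print(lazy.weight.shape)

# create the optimizer only after the parameters are materialized
optimizer = torch.optim.Adam(lazy.parameters(), lr=1e-3)
```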
For your manual approach you would also have to take care of pushing the module to the device after the first forward pass, which is done automatically when using the `Lazy*` modules.
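A sketch of the device caveat, again assuming a hypothetical `ManualModel` that creates its layer in the first forward pass: the newly created submodule lives on the CPU, so an explicit `.to(device)` is needed afterwards.

```python
import torch
import torch.nn as nn

class ManualModel(nn.Module):
    # Hypothetical stand-in for a model creating its layer lazily.
    def __init__(self):
        super().__init__()
        self.linear = None

    def forward(self, x):
        if self.linear is None:
            # nn.Linear is created on the CPU here, regardless of
            # any earlier .to(device) call on the (then empty) model
            self.linear = nn.Linear(x.size(1), 2)
        return self.linear(x)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ManualModel()
model(torch.randn(4, 8))  # first forward creates self.linear on the CPU
model.to(device)          # push the now-existing parameters explicitly
print(next(model.parameters()).device)
```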