I have a custom module that creates some parameters lazily. These parameters are initialized as `None`, but when `forward(input)` is called, they are "created", that is, replaced with instances of the `torch.nn.parameter.Parameter` class. I do this because the shapes of those parameters depend on the shape of the module's input, which I don't know until the module has been used once. (As a side note, this also means I need to pass an example through the network once before creating the optimizer, so that the lazily created parameters are registered in the optimizer.)
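
To make the setup concrete, here is a minimal sketch of the lazy-creation pattern described above. The class name `LazyScale` and the elementwise `weight` are assumptions made up for illustration, not the actual module in question.

```python
import torch
import torch.nn as nn


class LazyScale(nn.Module):
    """Hypothetical module whose weight shape depends on the first input."""

    def __init__(self):
        super().__init__()
        # Register the name with a None value; the real parameter comes later.
        self.register_parameter("weight", None)

    def forward(self, input):
        if self.weight is None:
            # First call: the per-example shape is now known, so create the parameter.
            self.weight = nn.Parameter(torch.ones(input.shape[1:]))
        return input * self.weight


model = LazyScale()
params_before = len(list(model.parameters()))  # None entries are skipped

# One example pass creates the parameter before the optimizer is built.
output = model(torch.randn(4, 3))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
params_after = len(list(model.parameters()))

print(params_before, params_after)
```

Note that the parameter here is created on the default device, which is exactly where the device question comes in.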
But how do I know which device to put the parameters on? I can pass a `device` argument to the module's `__init__` method, but if the module, `a`, is a child module of another module, `b`, and the programmer who uses the module calls `b.to(other_device)`, then the `to` method of `a` will not be called. As a result, `a` will not know that all parameters should now be located on `other_device` rather than on `device`, and the parameters will end up on the wrong device when they are created.
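
The situation can be reproduced with a small sketch. The classes `LazyChild` and `Parent` are hypothetical stand-ins for `a` and `b`, and the `"meta"` device is used only as an assumption so that the example runs without a GPU; in practice `other_device` would typically be something like `"cuda"`.

```python
import torch
import torch.nn as nn


class LazyChild(nn.Module):
    """Hypothetical stand-in for module a: remembers the device passed to __init__."""

    def __init__(self, device):
        super().__init__()
        self.device = torch.device(device)  # plain attribute, not touched by .to()
        self.register_parameter("weight", None)  # created lazily in forward

    def forward(self, input):
        if self.weight is None:
            # Uses the stored device, which may be stale by the time this runs.
            self.weight = nn.Parameter(
                torch.ones(input.shape[1:], device=self.device)
            )
        return input * self.weight


class Parent(nn.Module):
    """Hypothetical stand-in for module b."""

    def __init__(self):
        super().__init__()
        self.a = LazyChild(device="cpu")
        self.linear = nn.Linear(3, 3)

    def forward(self, input):
        return self.a(self.linear(input))


b = Parent()
b.to("meta")  # stands in for b.to(other_device)

print(b.linear.weight.device)  # the existing parameter was moved
print(b.a.device)              # the stored device was not updated
```

Parameters that already exist are moved along with `b`, but the device remembered by the child is a plain attribute and keeps its original value, so a parameter created afterwards would be placed on the old device.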
Any ideas for how to solve this?