Put lazily created parameters on the correct device

I have a custom module that creates some of its parameters lazily. These parameters are initialized as None, and when forward(input) is first called they are "created", i.e. replaced by instances of the torch.nn.parameter.Parameter class. I do this because the shapes of those parameters depend on the shape of the input to the module, which I don't know until the module has been used once. (As a side note, for this reason I also need to pass an example through the network once before creating the optimizer, so that the lazily created parameters are registered with the optimizer.)
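For concreteness, here is a minimal sketch of the kind of module I mean (the module name, parameter names, and shapes are just illustrative, not my actual code):

```python
import torch
import torch.nn as nn

class LazyAffine(nn.Module):
    """Illustrative module whose parameter shapes depend on the input shape."""
    def __init__(self):
        super().__init__()
        # Placeholders; the real shapes are unknown until forward() runs.
        self.register_parameter("weight", None)
        self.register_parameter("bias", None)

    def forward(self, input):
        if self.weight is None:
            features = input.shape[-1]
            # Created on the default device (CPU) -- this is the problem.
            self.weight = nn.Parameter(torch.ones(features))
            self.bias = nn.Parameter(torch.zeros(features))
        return input * self.weight + self.bias
```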

But how do I know which device to put the parameters on? I can pass a device argument to the module's __init__ method, but suppose the module, a, is a child module of another module, b, and the user calls b.to(other_device). That call only moves parameters and buffers that already exist; since a's parameters are still None, there is nothing to move, and a has no way of knowing that it should now create its parameters on other_device rather than on device. Consequently, the parameters end up on the wrong device when they are finally created.

Any ideas for how to solve this?


Since you are already using an init forward pass to create and register the parameters, you could also use the .device attribute of the input to create the parameters directly on the same device. I don't know how well this approach would work with a data parallel setup, but it should be fine for single-GPU use cases.
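Something like this, reusing the illustrative LazyAffine module from above (a sketch of the idea, not tested with data parallel):

```python
import torch
import torch.nn as nn

class LazyAffine(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter("weight", None)
        self.register_parameter("bias", None)

    def forward(self, input):
        if self.weight is None:
            features = input.shape[-1]
            # Create the parameters on the device (and dtype) of the input.
            self.weight = nn.Parameter(
                torch.ones(features, device=input.device, dtype=input.dtype))
            self.bias = nn.Parameter(
                torch.zeros(features, device=input.device, dtype=input.dtype))
        return input * self.weight + self.bias

device = "cuda" if torch.cuda.is_available() else "cpu"
parent = nn.Sequential(LazyAffine()).to(device)   # nothing to move yet
out = parent(torch.randn(8, 16, device=device))   # parameters land on `device`
optimizer = torch.optim.SGD(parent.parameters(), lr=0.1)  # after the init pass
```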
Note that PyTorch also provides lazy modules via nn.Lazy* (e.g. nn.LazyLinear), so have you checked them? They might be a clean replacement for your custom approach.
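E.g. a quick sketch of how nn.LazyLinear is used (the layer sizes here are arbitrary):

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# nn.LazyLinear infers in_features from the first forward pass.
model = nn.Sequential(nn.LazyLinear(32), nn.ReLU(), nn.LazyLinear(10)).to(device)

# Dry run: materializes the parameters on the module's current device.
model(torch.randn(4, 128, device=device))

# Create the optimizer only after the parameters have been materialized.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
print(model[0].weight.shape)  # torch.Size([32, 128])
```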


Using the same device as the input to the layer seems like a valid approach; I didn’t think of that! Thank you very much!