Put lazily created parameters on the correct device

Since you are already using an init forward pass to create and register the parameters, you could also use the .device attribute of the input to create (or move) the parameters on the same device. I'm not sure how well this approach works with data parallel setups, but it should be fine for single-GPU use cases.
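Here is a minimal sketch of this pattern (the LazyParamModule name and the shapes are hypothetical, just for illustration):

```python
import torch
import torch.nn as nn

class LazyParamModule(nn.Module):
    def __init__(self, out_features):
        super().__init__()
        self.out_features = out_features
        # register a placeholder so the parameter name exists up front
        self.register_parameter("weight", None)

    def forward(self, x):
        if self.weight is None:
            # create the parameter directly on the input's device (and dtype),
            # so no extra .to() call is needed after the init forward pass
            self.weight = nn.Parameter(
                torch.randn(self.out_features, x.size(-1),
                            device=x.device, dtype=x.dtype)
            )
        return x @ self.weight.t()

module = LazyParamModule(8)
x = torch.randn(4, 16)       # move x to the GPU and the weight will follow
out = module(x)
print(module.weight.device)  # matches x.device
```

One thing to keep in mind: the optimizer should be created after this init forward pass, since the parameter doesn't exist before it.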
Note that PyTorch also provides lazy modules via nn.Lazy* (e.g. nn.LazyLinear), so you might want to check them out, as they could be a clean replacement for your custom approach.
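A quick usage sketch (the shapes are just for illustration):

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.LazyLinear(128),  # in_features is inferred from the first input
    nn.ReLU(),
    nn.LazyLinear(10),
)

x = torch.randn(4, 37)        # any input feature size works for the first call
out = model(x)                # parameters are materialized during this call
print(model[0].weight.shape)  # torch.Size([128, 37])
```

As with the custom approach, the docs recommend moving lazy modules to the target device and running a dry-run forward pass before constructing the optimizer, so the materialized parameters end up in the optimizer's param groups.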
