Good point. In that case, this would be better:
    import torch.nn as nn

    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            # register the parameter name up front, so the module knows
            # about it even before the tensor is actually created
            self.register_parameter('weight', None)

        def reset_parameters(self, input):
            # lazily create the parameter with the same size, type and
            # device as the first input we see
            self.weight = nn.Parameter(input.new(input.size()).normal_(0, 1))

        def forward(self, input):
            if self.weight is None:
                self.reset_parameters(input)
            return self.weight @ input
`input.new` will create a new tensor of the same type as `input`, and it will be placed on the same GPU as `input`.
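A quick sketch of that behaviour (the tensor names here are just illustrative):

```python
import torch

x = torch.randn(4, 4)

# x.new(...) allocates an uninitialized tensor with the same dtype and
# device as x; normal_ then fills it in place with N(0, 1) samples
y = x.new(x.size()).normal_(0, 1)

print(y.dtype == x.dtype)    # True
print(y.device == x.device)  # True
print(y.shape == x.shape)    # True
```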