Dynamic parameter declaration in forward function

Good point. This would be better, then:

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # register the parameter name up front, before it has a value
        self.register_parameter('weight', None)

    def reset_parameters(self, input):
        # lazily create the parameter, matching the input's type and device
        self.weight = nn.Parameter(input.new(input.size()).normal_(0, 1))

    def forward(self, input):
        # on the first call, materialize the weight from the input
        if self.weight is None:
            self.reset_parameters(input)
        return self.weight @ input

input.new will create a new (uninitialized) tensor of the same type as input, and it will be placed on the same device as input, so on the same GPU if input is a CUDA tensor.
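
A minimal usage sketch along these lines (the 4x4 input shape and variable names are just illustrative assumptions) shows that the parameter only comes into existence on the first forward call:

module = MyModule()
print(dict(module.named_parameters()))  # {} -- the name is registered, but its value is still None, so it is skipped

x = torch.randn(4, 4)        # use torch.randn(4, 4).cuda() to check the same-device behaviour on GPU
out = module(x)              # first forward call triggers reset_parameters
print(module.weight.shape)   # torch.Size([4, 4]), same type and device as x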
