Defining parameters in the forward pass

In PyTorch, parameters can be created in the forward method rather than in __init__ when their shape depends on the size of the inputs.

Small example (as explained in this thread):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # register the parameter name first; the tensor itself is created lazily
        self.register_parameter('weight', None)

    def forward(self, input):
        if self.weight is None:
            # the shape depends on the input, so the parameter is created here
            self.weight = nn.Parameter(torch.randn(input.size()))
        return self.weight @ input

However, how can I load the parameters of a given state dict into such a model? (i.e. suppose I have a state_dict with the ‘weight’ parameter). Currently I’m getting this:

RuntimeError: Error(s) in loading state_dict for MyModule:
	Unexpected key(s) in state_dict: "weight".
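A minimal sketch of why this happens, assuming the class above: a parameter registered as None is skipped by state_dict(), so before the first forward pass the model exposes no 'weight' key at all, and load_state_dict() therefore reports the checkpoint's 'weight' as unexpected.

```python
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter('weight', None)

    def forward(self, input):
        if self.weight is None:
            self.weight = nn.Parameter(torch.randn(input.size()))
        return self.weight @ input

model = MyModule()
# Before the first forward pass the None parameter is excluded:
print(list(model.state_dict().keys()))  # -> []

model(torch.randn(2, 2))
# After the lazy creation the key exists, so loading can succeed:
print(list(model.state_dict().keys()))  # -> ['weight']
```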

Second question: the code above will not work when moving the model to GPU, as model.cuda() is usually called before forward(). How can I make it work both on CPU and GPU?

I don’t think there is a clean approach besides performing a “warmup” iteration using your real inputs to create the parameters. Afterwards, you could push the parameters to the desired device and also pass them to the optimizer. There is a newly added meta device, but I don’t think it would help here, given that the shapes are unknown. If I understand its use case correctly, it’s used to abstract the device, but not necessarily the data/shape.

Ok, so first I need to pass a dummy input to create the parameters:

# instantiate model defined above
model = MyModule()

# pass dummy input
dummy_input = torch.randn((2,2))
model(dummy_input)

Now I can load the state dict. Should I then call model.to(device), and will everything work?

Yes, once all parameters and layers are created, model.to() and the optimizer setup should work.
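Putting the thread together, here is a sketch of the full workflow. To keep it self-contained, the checkpoint is produced in memory from a second instance of the same module rather than loaded from a file:

```python
import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_parameter('weight', None)

    def forward(self, input):
        if self.weight is None:
            self.weight = nn.Parameter(torch.randn(input.size()))
        return self.weight @ input

# Pretend this state_dict came from a previous training run.
trained = MyModule()
trained(torch.randn(2, 2))            # warmup creates 'weight'
state_dict = trained.state_dict()

# Fresh model: warmup with a dummy input of the right shape,
# then load, move to the device, and set up the optimizer.
model = MyModule()
model(torch.randn(2, 2))              # creates the parameter so the keys match
model.load_state_dict(state_dict)     # now succeeds

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
```

The order matters: the dummy forward must come before load_state_dict(), and the optimizer should only be created after the parameters exist (and, ideally, after the move to the device).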
