How to automatically get in_features from nn.Conv2d to nn.Linear

Your approach would generally work, but you would have to be careful about when you create the optimizer, since the parameters are not fully initialized when the model is instantiated.
E.g. this standard approach would be wrong:

model = Model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

since model.parameters() does not yet contain the parameters of self.linear.
You would thus have to use:

model = Model()
out = model(dummy_input)  # first forward pass creates self.linear
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # now sees all parameters
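
To make the ordering issue concrete, here is a minimal sketch of what such a manual approach might look like. The `Model` definition, the layer sizes, and the 32x32 input shape are assumptions for illustration only, not your actual code:

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    # Hypothetical model: the linear layer is created lazily by hand.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.linear = None  # created during the first forward pass

    def forward(self, x):
        x = self.conv(x).flatten(1)
        if self.linear is None:
            # in_features is taken from the flattened conv output
            self.linear = nn.Linear(x.size(1), 10).to(x.device)
        return self.linear(x)


model = Model()
n_before = len(list(model.parameters()))  # only conv weight and bias

dummy_input = torch.randn(1, 3, 32, 32)  # assumed input shape
out = model(dummy_input)  # creates self.linear
n_after = len(list(model.parameters()))  # conv and linear parameters

# Create the optimizer only now, so it also tracks self.linear
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
print(n_before, n_after, model.linear.in_features)
```

An optimizer created before the dummy forward pass would only hold the two conv parameters and would never update the linear layer.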

The same limitation applies to the usage of Lazy* modules as described here.

For your manual approach you would also have to take care of pushing the module to the device after the first forward pass, since self.linear does not exist yet when model.to(device) is called earlier. The Lazy* modules should handle this automatically.
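
For comparison, a rough sketch of the same idea using `nn.LazyLinear`. The `LazyModel` name, the layer sizes, and the input shape are again assumptions for illustration:

```python
import torch
import torch.nn as nn


class LazyModel(nn.Module):
    # Hypothetical model using a lazy layer instead of manual creation.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.linear = nn.LazyLinear(10)  # in_features inferred on first call

    def forward(self, x):
        return self.linear(self.conv(x).flatten(1))


device = "cuda" if torch.cuda.is_available() else "cpu"

# The module can be moved to the device before the first forward pass
model = LazyModel().to(device)

dummy_input = torch.randn(1, 3, 32, 32, device=device)  # assumed input shape
out = model(dummy_input)  # materializes the lazy parameters

# As with the manual approach, create the optimizer after the dummy forward pass
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
print(model.linear.in_features, model.linear.weight.device)
```

The dummy forward pass is still needed before creating the optimizer, but no extra `.to(device)` call is required afterwards.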
