Initialize module only upon first input? (coding best practices)

This question is about best practices for designing subclasses of nn.Module. I am trying to decrease unnecessary coupling in my code.

For example, say my network has a Conv2d layer followed by a Linear layer. The hyperparameters are: size of the input image, number of conv channels, kernel size, stride, and number of FC layers. When building the model, I need to know the input size to the Linear layer in advance (as described in this answer). The arguments I must pass to initialize the Linear layer are therefore coupled to other hyperparameters: whenever I change a hyperparameter of a layer earlier in the chain (for example, image size or stride), I have to (or my code has to) recalculate the value. Presumably this is why there is demand for a Utility function for calculating the shape of a conv output. Conversely, the arguments needed to initialize a Conv2d layer are not coupled to hyperparameters earlier in the chain.

As a way of reducing coupling when writing custom modules, is it acceptable to move some initialization code from __init__() to forward()? forward() would then look like the pseudocode below.

def forward(self, x):
    if not self.initialized:  # self.initialized is set to False in __init__()
        # Lazily create inner modules / parameters using the input's size
        self.weights = nn.Parameter(torch.randn(..., x.size(2), x.size(3), ...))
        self.initialized = True

    # Go on to the usual forward() computation
    ...
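
For concreteness, here is a more complete sketch of what I have in mind (the class name and shapes are just illustrative; it assumes a flatten-then-linear layer whose in_features depends on the input size):

import torch
import torch.nn as nn

class LazyHead(nn.Module):
    """Flatten + linear layer whose in_features is inferred from the first input."""
    def __init__(self, out_features):
        super().__init__()
        self.out_features = out_features
        self.initialized = False

    def forward(self, x):
        if not self.initialized:
            in_features = x.size(1) * x.size(2) * x.size(3)  # C * H * W of the incoming batch
            # Create the parameters lazily, on the same device as the input
            self.weight = nn.Parameter(torch.randn(self.out_features, in_features, device=x.device))
            self.bias = nn.Parameter(torch.zeros(self.out_features, device=x.device))
            self.initialized = True
        return x.flatten(1) @ self.weight.t() + self.bias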

What would be the drawbacks of this approach? Are there other approaches to reducing coupling?

One shortcoming I can think of is that you would need a random forward pass to create the parameters (and thus the state_dict) before you can pass them to the optimizer.

The same would also apply if you save the trained state_dict, create the model in another script, and try to load the state_dict.
Your model would be empty until you run a dummy forward pass.
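
For illustration, with a lazily-initialized module like the one you sketched (using the hypothetical LazyHead above):

model = LazyHead(out_features=10)
print(list(model.parameters()))   # [] -- nothing to hand to the optimizer yet
# torch.optim.SGD(model.parameters(), lr=0.1) would raise "optimizer got an empty parameter list",
# and model.load_state_dict(saved_state) would fail with unexpected keys.

model(torch.randn(1, 3, 24, 24))  # dummy forward pass creates the parameters
print([p.shape for p in model.parameters()])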

I don’t really like the idea. You could avoid this by calling forward() in the __init__ method, but that would obviously defeat the purpose, since you would need to know the input shape.

Thanks @ptrblck for pointing out that shortcoming. There might be some way to work around it by modifying load_state_dict or using placeholder values, but I suppose it will always be better to have __init__ actually initialize the module, not least because that’s what everyone expects.

Do you know of any commonly-used approaches to simplify (or make more readable) model-building code? I’m now thinking of something like a helper method which returns the output size of forward given an input size.

As far as I understand, the aspect you would like to simplify is mostly the in_features of nn.Linear layers.
One way would be to use nn.AdaptiveMaxPool2d or to go fully convolutional.
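
For example (a small sketch; the 64 channels are just an assumption about the preceding conv layers), adaptive pooling fixes the spatial size, so in_features no longer depends on the input resolution:

head = nn.Sequential(
    nn.AdaptiveMaxPool2d((7, 7)),  # always outputs C x 7 x 7, whatever the input H x W
    nn.Flatten(),
    nn.Linear(64 * 7 * 7, 10),     # in_features only depends on the channel count
)
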
Of course you could try to calculate the shape of in_features using the shape formula for conv layers etc.
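
For reference, that formula wrapped in a small helper might look like this (a sketch; the name is just illustrative):

def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    # floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1), as in the nn.Conv2d docs
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

conv2d_out_size(24, kernel_size=3, stride=2)  # 11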

“calculate the shape of in_features using the shape formula for conv layers etc.”

Yes, I suppose I am talking about a way to do that, not specific to nn.Linear.

The motivation for my question is that I have seen several implementations of experimental network architectures with frustrating (anti)patterns such as:

  • magic numbers (e.g. a layer initialized with 1152 neurons as a result of unexplained calculations)
  • user-settable parameters with unnecessary coupling (e.g. a command line parameter --unit-size=1152 which cannot be freely changed in reality, because it is determined by the input image size and other hyperparameters)

As a new PyTorch user, I am wondering if there is a “best practices” way to avoid such patterns in my own code, especially when writing custom subclasses of nn.Module.

BTW, I see my “helper method” idea is really the same as Carson McNeil’s idea here.

I see the issues. A cleaner way would be to write in_features as the explicit shape calculation, e.g. in_features=512*7*7, so that it is clear the last conv layer outputs 512 channels with a spatial size of 7x7.

This can be easily done by passing a random input to the conv layer:

nn.Conv2d(3, 6, 3, 1, 1)(torch.randn(1, 3, 24, 24)).shape  # torch.Size([1, 6, 24, 24])
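
Wrapped in a helper (just a sketch; the name and signature are illustrative), this is essentially the “output size of forward” method you mentioned:

def num_flat_features(module, input_shape):
    # Infer the flattened feature size by running a dummy input through the module
    with torch.no_grad():
        out = module(torch.zeros(1, *input_shape))
    return int(out.numel())

num_flat_features(nn.Conv2d(3, 6, 3, 1, 1), (3, 24, 24))  # 3456 = 6 * 24 * 24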

However, you would still need to know the input shape of your image.
I’m not sure there is a clean way to achieve the shape inference currently.
Maybe someone else has a good idea.
