What is the best way to declare the flatten layer in PyTorch?

In general, we have to calculate the output dimension of the previous layer in order to declare the fully connected layer that follows the flattening step after the convolution and pooling operations.
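To make the manual approach concrete, here is a minimal sketch of the calculation done by hand. It applies the standard output-size rule for Conv2d/MaxPool2d (with ceil_mode=False) to one spatial side at a time. The helper name `conv_out` and the small CNN (1x28x28 input, two conv + pool stages, 32 output channels) are hypothetical choices for illustration only.

```python
# Hypothetical example: compute the flattened size of a small CNN by hand.
# For Conv2d and MaxPool2d (ceil_mode=False), each spatial side shrinks as:
#   out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1

def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of one spatial dimension after a conv or pooling layer."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# Assumed architecture (illustrative only):
#   1x28x28 -> Conv2d(1, 16, kernel_size=5) -> MaxPool2d(2)
#           -> Conv2d(16, 32, kernel_size=5) -> MaxPool2d(2)
h = w = 28
h, w = conv_out(h, 5), conv_out(w, 5)                      # 24 x 24
h, w = conv_out(h, 2, stride=2), conv_out(w, 2, stride=2)  # 12 x 12
h, w = conv_out(h, 5), conv_out(w, 5)                      # 8 x 8
h, w = conv_out(h, 2, stride=2), conv_out(w, 2, stride=2)  # 4 x 4

out_channels = 32
flat_features = out_channels * h * w  # input size for the first nn.Linear
print(flat_features)  # 512
```

The drawback is visible here: every change to a kernel size, stride, or padding means redoing this arithmetic, which is what motivates the alternatives below.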

Some forum posts also suggest helper functions for this purpose, such as:

import torch
def count_input_neuron(model, image_dim):
    with torch.no_grad():  # dummy forward pass; only the output shape is needed
        return model(torch.rand(1, *image_dim)).reshape(1, -1).size(1)

which can be found on the PyTorch forum and Stack Overflow,

whereas others use the flatten function from torch (torch.flatten).
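The helper-function approach can be sketched end to end as follows: a dummy forward pass measures the flattened feature count, which then sizes the nn.Linear layer. This restates `count_input_neuron` so the snippet runs on its own; the feature extractor layers, the 1x28x28 input, and the 10 output classes are hypothetical.

```python
import torch
import torch.nn as nn

def count_input_neuron(model, image_dim):
    # Push one random sample through the model and count features per sample
    with torch.no_grad():
        return model(torch.rand(1, *image_dim)).reshape(1, -1).size(1)

# Hypothetical feature extractor (illustrative layer choices)
features = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=5),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(16, 32, kernel_size=5),
    nn.ReLU(),
    nn.MaxPool2d(2),
)

n_in = count_input_neuron(features, (1, 28, 28))
print(n_in)  # 512
classifier = nn.Linear(n_in, 10)

x = torch.rand(8, 1, 28, 28)  # batch of 8 hypothetical images
logits = classifier(torch.flatten(features(x), 1))  # keep batch dim, flatten rest
print(logits.shape)  # torch.Size([8, 10])
```

The probe only has to run once at construction time, so its cost is negligible, and the Linear size stays correct if the convolutional layers change.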

You could either use the nn.Flatten module directly or use the functional API via torch.flatten or tensor.view in the forward method.
The approach depends on your use case (e.g. nn.Flatten should be used in an nn.Sequential container) and your coding style.
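Both styles from the answer can be sketched side by side. The layer sizes below (a single 5x5 conv and 2x2 pool on a 1x28x28 input, giving 16x12x12 features) are hypothetical; the point is where the flatten happens.

```python
import torch
import torch.nn as nn

# Style 1: nn.Flatten as a layer inside an nn.Sequential container
seq_model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=5),  # 28x28 -> 24x24
    nn.ReLU(),
    nn.MaxPool2d(2),                  # 24x24 -> 12x12
    nn.Flatten(),                     # default start_dim=1 keeps the batch dim
    nn.Linear(16 * 12 * 12, 10),
)

# Style 2: functional flatten inside a custom forward()
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 16, kernel_size=5)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(16 * 12 * 12, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv(x)))
        x = torch.flatten(x, 1)  # equivalent here to x.view(x.size(0), -1)
        return self.fc(x)

x = torch.rand(4, 1, 28, 28)
print(seq_model(x).shape)  # torch.Size([4, 10])
print(Net()(x).shape)      # torch.Size([4, 10])
```

Note that torch.flatten defaults to start_dim=0, so calling torch.flatten(x) without the second argument would also flatten the batch dimension, whereas nn.Flatten defaults to start_dim=1. Also, tensor.view requires a memory layout compatible with the new shape, while torch.flatten falls back to copying when a view is not possible.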