Flatten all dimensions except batch?

I am following the CIFAR-10 tutorial (Training a Classifier — PyTorch Tutorials 1.11.0+cu102 documentation). It is pretty nice and I understand it pretty well, but there is a line in the forward method that says:

x = torch.flatten(x, 1)  # flatten all dimensions except batch
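One way to see what the comment means is to compare shapes. The sketch below assumes a batch of 4 samples shaped like the output of the tutorial's second pooling layer (16 channels of 5x5 feature maps); the tensor itself is random and only the shapes matter.

```python
import torch

# Illustrative batch: 4 samples, each with 16 channels of 5x5 feature maps
x = torch.randn(4, 16, 5, 5)

# start_dim=1 keeps dim 0 (the batch) and merges dims 1..3 into one
per_sample = torch.flatten(x, 1)
print(per_sample.shape)  # torch.Size([4, 400])

# The default start_dim=0 would also merge the batch dimension away
everything = torch.flatten(x)
print(everything.shape)  # torch.Size([1600])

# Each row of per_sample is exactly one sample's values, unrolled
print(torch.equal(per_sample[0], x[0].reshape(-1)))  # True
```

With `start_dim=1`, the samples stay separate; with the default, all four samples collapse into one long vector.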

The comment confused me: do I always have to flatten or not, and why? I assumed that everything had to be flattened.

Thank you so much, fellas. I am new to this journey.

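Flattening simply reshapes your input from an n-dimensional tensor into a vector format that the densely connected layers of your network can interpret. Passing 1 as the start dimension leaves dimension 0, the batch, untouched, so each sample is flattened into its own vector: in the tutorial, a [4, 16, 5, 5] tensor becomes [4, 400], which is what fc1 = nn.Linear(16 * 5 * 5, 120) expects.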
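A minimal sketch of the flatten-then-dense step, assuming the same shapes as the tutorial's first fully connected layer (batch of 4, 16 x 5 x 5 features in, 120 out). It also shows the older `reshape(batch_size, -1)` idiom, which produces the same result as `torch.flatten(x, 1)`.

```python
import torch
import torch.nn as nn

# Assumed shapes, mirroring the tutorial's conv/pool output (16 channels, 5x5)
batch_size = 4
features = torch.randn(batch_size, 16, 5, 5)

# Dense layer that expects one 400-element vector per sample
fc1 = nn.Linear(16 * 5 * 5, 120)

flat = torch.flatten(features, 1)  # shape [4, 400]
out = fc1(flat)                    # shape [4, 120]: one output row per sample

# Equivalent idiom: keep the batch size, let -1 infer the remaining size
same = features.reshape(features.size(0), -1)
print(torch.equal(flat, same))  # True

# Flattening the batch too yields one 1600-element vector, whose size
# no longer matches in_features=400, so fc1 could not consume it
mismatch = torch.flatten(features).shape[-1] != fc1.in_features
print(mismatch)  # True
```

Keeping the batch dimension means the linear layer processes every sample independently and returns one prediction row per image.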
