Spatial classifier

Hi all, how do you build a spatial classifier with PyTorch Linear layers?
By spatial classifier I mean one that can take 2d inputs.

If the two dimensions are small enough (e.g. MNIST images), you can flatten your inputs and use a 1d linear layer via

input_1d = input_2d.view(-1)  # single 2d input; for a batch use input_2d.view(input_2d.size(0), -1)
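A minimal sketch of the flattening approach for a batch of MNIST-sized images, assuming 1x28x28 inputs and 10 classes (the `FlatClassifier` name and sizes are illustrative, not from the thread). `nn.Flatten()` keeps the batch dimension and flattens everything else, so the Linear layer sees one 784-long vector per image:

```python
import torch
import torch.nn as nn

# Hypothetical classifier for 1 x 28 x 28 inputs, assuming 10 output classes.
class FlatClassifier(nn.Module):
    def __init__(self, height=28, width=28, num_classes=10):
        super().__init__()
        # nn.Flatten() defaults to start_dim=1, so the batch dim is preserved
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(height * width, num_classes)

    def forward(self, x):
        # x: (batch, 1, H, W) -> (batch, H*W) -> (batch, num_classes)
        return self.fc(self.flatten(x))

model = FlatClassifier()
images = torch.randn(8, 1, 28, 28)
logits = model(images)
print(logits.shape)  # torch.Size([8, 10])
```

This only works when the input size is fixed, because `nn.Linear` needs a known number of input features.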

We ended up replacing the linear layers with 1x1 convolutions, which apply the same linear map to the channel vector at every spatial position:

nn.Conv2d(4096, 4096, kernel_size=1)
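To check that a 1x1 convolution really is a Linear layer applied at each spatial position, here is a small sketch that copies a Linear layer's weights into an `nn.Conv2d(..., kernel_size=1)` and compares outputs. The sizes (16 in, 8 out, a 5x7 spatial map) are arbitrary assumptions chosen to keep it fast; the same reasoning applies to 4096 channels:

```python
import torch
import torch.nn as nn

in_features, out_features = 16, 8  # small stand-ins for 4096

linear = nn.Linear(in_features, out_features)
conv = nn.Conv2d(in_features, out_features, kernel_size=1)

# Copy Linear weights into the 1x1 conv: (out, in) -> (out, in, 1, 1)
with torch.no_grad():
    conv.weight.copy_(linear.weight.view(out_features, in_features, 1, 1))
    conv.bias.copy_(linear.bias)

x = torch.randn(2, in_features, 5, 7)  # batch of 2, spatial size 5x7

conv_out = conv(x)  # (2, out_features, 5, 7)

# Apply the Linear layer to the channel vector at every position:
# move channels last, apply linear, then move channels back.
linear_out = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

print(conv_out.shape)  # torch.Size([2, 8, 5, 7])
print(torch.allclose(conv_out, linear_out, atol=1e-5))
```

Unlike the flattening approach, the conv version accepts any height and width and returns one prediction per spatial position.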