So I have been experimenting with multiple CNN architectures and wanted to clear up some ideas. I first tried the usual CNN + flattening + linear layers, and that beats the baselines, but not by a large margin. Suppose I have an image with N pixels. I tried the following:
CNN → no flattening → linear layer to reduce the number of channels (reshaping the input of this linear layer so that channels are the last dimension, in accordance with PyTorch's convention) → output an image with the same dimensions as the input
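A minimal sketch of that first head, with made-up sizes (batch 8, 32 CNN channels, a 16×16 grid are my illustrative assumptions, not numbers from the post). The key point is that `nn.Linear` acts on the last dimension, hence the permute:

```python
import torch
import torch.nn as nn

# Hypothetical sizes: batch 8, 32 channels out of the CNN, 16x16 image.
B, C, H, W = 8, 32, 16, 16
feat = torch.randn(B, C, H, W)            # CNN output, not flattened

# nn.Linear works on the last dimension, so move channels there first.
reduce = nn.Linear(C, 1)                  # 32 channels -> 1 value per pixel
out = reduce(feat.permute(0, 2, 3, 1))    # (B, H, W, 1)
out = out.squeeze(-1)                     # (B, H, W): same spatial dims as input
```

Note that this applies the *same* C→1 map at every pixel, which is exactly what a 1×1 convolution does.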

CNN → no flattening → convolutional layer to reduce the number of channels to only 1 → flatten the one remaining channel → linear layers → output of N values (N is the number of pixels in the image)
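A sketch of this second head under the same assumed sizes (batch 8, 32 channels, 16×16 grid; all hypothetical):

```python
import torch
import torch.nn as nn

B, C, H, W = 8, 32, 16, 16
N = H * W                             # number of pixels
feat = torch.randn(B, C, H, W)        # CNN output, not flattened

head = nn.Sequential(
    nn.Conv2d(C, 1, kernel_size=1),   # collapse channels to 1
    nn.Flatten(),                     # (B, 1, H, W) -> (B, N)
    nn.Linear(N, N),                  # dense mixing across all pixels
)
out = head(feat)                      # (B, N); reshape to (B, H, W) if needed
```

The final `Linear(N, N)` is N² parameters, so it grows quickly with grid size; that is one design trade-off of flattening at the end.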

CNN → no flattening → N parallel linear layers (with in_features = num_channels at the end of the CNN), each looking at the same pixel across all channels → output N values and reshape them into an image
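A sketch of the per-pixel-heads idea in its naive loop form, with small illustrative sizes (batch 4, 32 channels, 8×8 grid; all assumptions). This is the version that gets slow, since it runs H·W tiny matrix multiplies per forward pass:

```python
import torch
import torch.nn as nn

B, C, H, W = 4, 32, 8, 8
feat = torch.randn(B, C, H, W)

# One tiny Linear per pixel; head i*W+j sees pixel (i, j)'s channel vector.
pixel_heads = nn.ModuleList(nn.Linear(C, 1) for _ in range(H * W))

cols = []
for i in range(H):
    for j in range(W):
        cols.append(pixel_heads[i * W + j](feat[:, :, i, j]))  # (B, 1)
out = torch.cat(cols, dim=1).reshape(B, H, W)
```

Collecting into a list and concatenating (rather than writing into a preallocated tensor) keeps autograd happy, but the Python loop itself is the bottleneck.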

CNN → grouped convolutions with kernel size = 1. This one I haven't managed to get running as intended, but it is the same idea as the previous architecture: having separate parameters for each pixel after the initial CNN. I don't know if that is possible with grouped convolutions...
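For context on why this is tricky: `groups` in `nn.Conv2d` partitions the *channel* dimension, not the spatial one, and the same kernel weights are still shared across every pixel. A quick sketch of what a grouped 1×1 convolution actually computes (sizes are illustrative assumptions):

```python
import torch
import torch.nn as nn

B, C, H, W = 4, 32, 8, 8
feat = torch.randn(B, C, H, W)

# groups=4 splits the 32 channels into 4 groups of 8: each group maps its own
# 8 input channels to 8 output channels, with no mixing across groups.
gconv = nn.Conv2d(C, C, kernel_size=1, groups=4)
out = gconv(feat)   # (B, C, H, W); identical weights applied at every pixel
```

So grouped convolutions give you parallel *channel-group* parameters, not parallel *per-pixel* parameters, which may be why this version never behaved as intended.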
And the results were (in the same order):
The model wasn’t able to learn anything: the training loss stayed constant. I checked the initializations and each layer’s outputs, and it doesn’t seem to be a vanishing-gradients problem. Adding more parameters did not help.

Almost the same results as the initial try (the conventional CNN model).

Here, with just one group of N parallel linear layers, it managed to decrease the loss a bit. But when I added another group of N linear layers, training got very slow. I suspect it is because I use nested for loops to feed each linear layer its inputs.
I’m looking into einsum notation to maybe get rid of the for loops. The problem is: I don’t know if I can have only one weight tensor for all my linear layers and do the multiplication “by hand” on the outputs of the CNN, using einsum notation to get the desired output for each parallel linear layer.
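For what it's worth, a single weight tensor plus one einsum can express all N per-pixel linear layers at once. A minimal sketch, with the same illustrative sizes as above and hypothetical names (`weight`, `bias`):

```python
import torch

B, C, H, W = 4, 32, 8, 8
feat = torch.randn(B, C, H, W)

# One weight tensor holding all H*W per-pixel linear maps, plus per-pixel bias.
# (Wrap these in nn.Parameter inside a module to make them trainable.)
weight = torch.randn(H, W, C) * (C ** -0.5)   # one C-vector per pixel
bias = torch.zeros(H, W)

# For each pixel (h, w): out[b,h,w] = sum_c feat[b,c,h,w] * weight[h,w,c] + bias[h,w]
out = torch.einsum('bchw,hwc->bhw', feat, weight) + bias
```

This does exactly what the nested-loop version does, but as one batched operation, so the slowdown from Python loops goes away.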
I have tried architectures that do not flatten the output of the convolutions to see whether keeping the learnt spatial patterns helps, because in the problem at hand I think it should. But I haven’t really gotten good results. Is it worth not flattening, and why (or why not)?
So, what I’m asking is: what can I do better in the architectures I’ve already tried? What new architectures do you recommend? Do you have any resources I can read about CNNs without flattening (or about why we should always flatten)?
PS: if it helps, I’m trying to do weather forecasting on a 2D spatial grid using previous time steps.
Sorry for the long post(and the bad English). Thank you for your time!