Regression model only gives a single output

I’m having trouble getting my network to learn - it always seems to predict the same result regardless of input.

The goal of the network is to interpret 4D data and return a 2D array. The data is a simulation of a particular laser-speckle imaging setup, run for a number of different starting positions and 3 simulated experimental settings; in brief, this results in data of shape [5, 6, 21, 22]. Each simulation is driven by a 2D [21, 21] array, so the simulated data is unique to the 2D array given as input. That same 2D array is the object the network should predict.
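To make the shapes concrete, here is a minimal sketch of a single sample and its target, plus a small batch (the variable names, random values, and the batch size of 8 are purely illustrative):

    import torch

    # One simulated sample (4D) and the 2D array that drove the simulation,
    # which is also what the network should predict.
    sample = torch.randn(5, 6, 21, 22)   # simulated speckle data for one input array
    target = torch.randn(21, 21)         # the 2D array the network should recover

    # A small batch, e.g. the 8 data points mentioned further down:
    batch_x = torch.randn(8, 5, 6, 21, 22)
    batch_y = torch.randn(8, 21, 21)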

My approach is to use convolutional layers followed by fully connected layers. This is the same approach I used when the problem was predicting 1D arrays from 3D data, which worked nearly perfectly. My assumption is that something went wrong when moving everything up a dimension, but maybe someone has another idea of what might be going wrong.

My forward function:

    def forward(self, x):
        x = F.relu(self.convs(x)) # Uses convNd function from https://github.com/pvjosue/pytorch_convNd
        x = x.view(-1, x.shape[0], x.shape[1]*x.shape[2]*x.shape[3]*x.shape[4]*x.shape[5]) #flatten
        x = self.max1d(x.permute(0,1,2)).permute(0,1,2) # pooling nn.MaxPool1d(2, 2)
        x = self.max1d(x.permute(0,1,2)).permute(0,1,2) # pooling nn.MaxPool1d(2, 2)
        x = F.relu(self.fc1(x)) # nn.Linear(self._to_linear,1024), _to_linear = torch.flatten(x).shape[0] after convs
        x = F.relu(self.fc2(x)) # nn.Linear(1024,21*21)
        x = self.fc3(x) # nn.Linear(21*21,21*21)
        x = x.view(-1,21,21)
        return x
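For reference, this is roughly what the flattening view does to a dummy post-convolution tensor; the channel and spatial sizes below are made up, only the structure matters:

    import torch

    # Hypothetical output of self.convs: [batch, channels, d1, d2, d3, d4]
    x = torch.randn(8, 4, 3, 4, 19, 20)

    flat = x.view(-1, x.shape[0], x.shape[1]*x.shape[2]*x.shape[3]*x.shape[4]*x.shape[5])
    print(flat.shape)  # torch.Size([1, 8, 18240]); the batch size (8) ends up in dimension 1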

When I train on a single data point, the network is able to reproduce that output consistently. However, when I scale up to 8 data points (something that should be trivially memorizable, without any real generalization), it becomes incapable of producing different results for different inputs, while the loss also stops decreasing. The images below show an example of what the input/output looks like.
(image "op3": example of an input and the corresponding network output)
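To put "different results" more concretely: a helper along the following lines (purely illustrative, not my actual code) captures what I mean; the per-pixel spread of the predictions across the 8 inputs is essentially zero:

    import torch

    def prediction_spread(net, batch_x):
        """Largest per-pixel standard deviation of the predictions across a batch.

        If the network really returns the same output regardless of input,
        this is ~0.
        """
        with torch.no_grad():
            preds = net(batch_x)          # expected shape [batch, 21, 21]
            return preds.std(dim=0).max()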

Does anyone have an idea what could be causing this? If more information or code would help, please let me know!