Train loss decreases, but prediction does not work in CNN Encoder-Decoder model

Issue summary

I am working on a CNN encoder-decoder model for image prediction.
The input is a 2-channel image and the output is a 1-channel image; the model is trained with MSE loss and the SGD optimizer.
The training loss decreases, but the predicted images look wrong.

Does anyone have a clue about this issue?
I have changed the network design and the hyperparameters again and again,
but the prediction results always show a similar problem.
Thank you in advance.

The network details

The 2-channel input is a float tensor of 64 by 64 pixels, and the label is a 1-channel float image, also 64 by 64 pixels.
Each input channel is encoded independently by its own CNN, and the two encodings are concatenated and then decoded.
The decoder uses convolutional layers with ReLU activations and returns the regression image.
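To make that data flow concrete, here is a small NumPy sketch (shapes only, no PyTorch) of splitting a batch of 2-channel inputs into two 1-channel tensors and concatenating two encoder outputs along the channel axis. The batch size of 4 and the zero-filled arrays are placeholders, and the (N, 32, 8, 8) feature-map shape assumes the three 2x2 pooling stages in the encoder.

```python
import numpy as np

batch = 4
# Placeholder batch of 2-channel, 64x64 inputs in (N, C, H, W) layout
x = np.zeros((batch, 2, 64, 64), dtype=np.float32)

# Slicing with a range keeps the channel axis, giving (N, 1, H, W) tensors
x1 = x[:, 0:1, :, :]
x2 = x[:, 1:2, :, :]
print(x1.shape, x2.shape)

# A slice past the last channel is silently empty rather than an error
empty = x[:, 2:4, :, :]
print(empty.shape)

# Each encoder branch is assumed to produce a (N, 32, 8, 8) feature map
f1 = np.zeros((batch, 32, 8, 8), dtype=np.float32)
f2 = np.zeros((batch, 32, 8, 8), dtype=np.float32)
# Same role as torch.cat((f1, f2), dim=1) in the model
fused = np.concatenate((f1, f2), axis=1)
print(fused.shape)
```

PyTorch tensor slicing behaves the same way, so an out-of-range channel slice produces a tensor with zero channels instead of raising an indexing error.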
The network is as follows:

import torch
import torch.nn as nn


class EncoderDecoderRegression(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder network: one branch per single-channel input
        self.conv_x1 = nn.Sequential(nn.Conv2d(1, 8, 3, 1, 1),
                                     nn.MaxPool2d(2, 2, 0),
                                     nn.BatchNorm2d(8),
                                     nn.ReLU(),
                                     nn.Conv2d(8, 16, 3, 1, 1),
                                     nn.MaxPool2d(2, 2, 0),
                                     nn.BatchNorm2d(16),
                                     nn.ReLU(),
                                     nn.Conv2d(16, 32, 3, 1, 1),
                                     nn.MaxPool2d(2, 2, 0),
                                     nn.BatchNorm2d(32),
                                     nn.ReLU()
                                     )
        self.conv_x2 = nn.Sequential(nn.Conv2d(1, 8, 3, 1, 1),
                                     nn.MaxPool2d(2, 2, 0),
                                     nn.BatchNorm2d(8),
                                     nn.ReLU(),
                                     nn.Conv2d(8, 16, 3, 1, 1),
                                     nn.MaxPool2d(2, 2, 0),
                                     nn.BatchNorm2d(16),
                                     nn.ReLU(),
                                     nn.Conv2d(16, 32, 3, 1, 1),
                                     nn.MaxPool2d(2, 2, 0),
                                     nn.BatchNorm2d(32),
                                     nn.ReLU()
                                     )
        # Decoder network
        self.decoder = nn.Sequential(nn.Upsample(scale_factor=2),
                                     nn.Conv2d(64, 32, 3, 1, 1),
                                     nn.ReLU(),
                                     nn.Upsample(scale_factor=2),
                                     nn.Conv2d(32, 16, 3, 1, 1),
                                     nn.ReLU(),
                                     nn.Upsample(scale_factor=2),
                                     nn.Conv2d(16, 1, 3, 1, 1),
                                     nn.ReLU())

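    def forward(self, x_1, x_2):
        # Encode each single-channel input with its own branch
        # (the original used x1/x2 here, which silently read global variables)
        x_1 = self.conv_x1(x_1)
        x_2 = self.conv_x2(x_2)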
        x = torch.cat((x_1, x_2), dim=1)
        x = self.decoder(x)
        return x
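A quick way to confirm that the layer sizes line up is to trace the channel count and spatial size through the layers by hand. The pure-Python sketch below uses hypothetical helpers `conv_out` and `trace` (not PyTorch APIs) and assumes each encoder receives a 1-channel 64 by 64 input. It reports a 32-channel, 8 by 8 bottleneck per branch and a 1-channel, 64 by 64 decoder output.

```python
def conv_out(size, kernel, stride, padding):
    # Output-size formula shared by Conv2d and MaxPool2d (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1


def trace(layers, channels, size):
    # "conv" = Conv2d(k=3, s=1, p=1), "pool" = MaxPool2d(2, 2, 0),
    # "up" = Upsample(scale_factor=2); BatchNorm2d and ReLU keep the shape
    for kind, out_ch in layers:
        if kind == "conv":
            channels, size = out_ch, conv_out(size, 3, 1, 1)
        elif kind == "pool":
            size = conv_out(size, 2, 2, 0)
        elif kind == "up":
            size = size * 2
        else:
            raise ValueError(f"unknown layer kind: {kind}")
    return channels, size


encoder = [("conv", 8), ("pool", None), ("conv", 16), ("pool", None),
           ("conv", 32), ("pool", None)]
decoder = [("up", None), ("conv", 32), ("up", None), ("conv", 16),
           ("up", None), ("conv", 1)]

enc_c, enc_s = trace(encoder, channels=1, size=64)
print(enc_c, enc_s)
# The two encoder outputs are concatenated along the channel axis
dec_c, dec_s = trace(decoder, channels=2 * enc_c, size=enc_s)
print(dec_c, dec_s)
```

So the model output has shape (N, 1, 64, 64). If the labels come out of the loader as (N, 64, 64) instead, `nn.MSELoss` broadcasts the two tensors to (N, N, 64, 64) and prints a warning about mismatched target and input sizes, which is worth ruling out before tuning anything else.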

Training process

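import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# train_loader and val_loader are assumed to yield (label, input) batches
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EncoderDecoderRegression().to(device)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)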

# Train the model
epoch = 50

train_loss, val_loss = [], []
for i in tqdm(range(epoch)):
    # Train loop ----------------------------
    model.train()
    train_batch_loss = []
    for y, x in train_loader:
        y, x = y.to(device).float(), x.to(device).float()
        # Split the 2-channel input into two (N, 1, 64, 64) tensors
        x1 = x[:, 0:1, :, :]
        x2 = x[:, 1:2, :, :]
        optimizer.zero_grad()
        output = model(x1, x2)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        train_batch_loss.append(loss.item())
    # val loop ----------------------------
    model.eval()
    val_batch_loss = []
    with torch.no_grad():
        for y, x in val_loader:
            y, x = y.to(device).float(), x.to(device).float()
            # Same channel split as in the training loop
            x1 = x[:, 0:1, :, :]
            x2 = x[:, 1:2, :, :]
            output = model(x1, x2)
            loss = criterion(output, y)
            val_batch_loss.append(loss.item())
    # Collect loss
    train_loss.append(np.mean(train_batch_loss))
    val_loss.append(np.mean(val_batch_loss))
    print(i, f"Train loss: {train_loss[-1]:.3f}, Val loss: {val_loss[-1]:.3f}")

In my opinion, the labels need some kind of normalization,
since their values are quite large (e.g., 40.0).
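One simple form of label normalization is to divide every label by a single scale computed from the training labels and to multiply predictions by the same scale afterwards. The NumPy sketch below assumes the scale is the largest absolute training label. Dividing by the maximum keeps the labels non-negative, which matters here because the decoder ends with ReLU and cannot output negative values. The helper names and the synthetic sparse labels are hypothetical.

```python
import numpy as np


def fit_label_scale(train_labels):
    # One scalar from the training labels only; fall back to 1.0 if all zero
    return float(np.abs(train_labels).max()) or 1.0


def scale_labels(labels, scale):
    # Applied to labels before computing the loss
    return labels / scale


def unscale_predictions(pred, scale):
    # Map network outputs back to the original label units
    return pred * scale


rng = np.random.default_rng(0)
# Hypothetical sparse labels: mostly zeros with a few pixels at 40.0
train_y = np.zeros((8, 1, 64, 64), dtype=np.float32)
train_y[rng.random(train_y.shape) < 0.02] = 40.0

scale = fit_label_scale(train_y)
y_scaled = scale_labels(train_y, scale)
print(scale, y_scaled.min(), y_scaled.max())
restored = unscale_predictions(y_scaled, scale)
```

In the training loop, this would amount to computing the loss against `y / scale` and multiplying the model output by `scale` before inspecting predictions.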

Also, the labels seem to be sparse (most regions are close to zero), so you might want to try a loss other than MSE, which tends to produce blurry predictions.
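To see how the training loss can fall while the predictions still look wrong, the NumPy sketch below computes plain MSE for two useless predictions on a hypothetical sparse label map: all zeros, and a constant equal to the label mean. Both score about 0.01. The sketch also shows one possible alternative, a weighted MSE that upweights non-zero pixels. This is only one assumed option, and the weight of 10 and threshold of 0 are arbitrary illustration values.

```python
import numpy as np


def mse(pred, target):
    # Plain mean squared error over all pixels
    return float(np.mean((pred - target) ** 2))


def weighted_mse(pred, target, nonzero_weight=10.0, threshold=0.0):
    # Hypothetical variant: pixels whose label exceeds `threshold` are
    # weighted by `nonzero_weight`, so missing the sparse signal costs more
    weight = np.where(target > threshold, nonzero_weight, 1.0)
    return float(np.mean(weight * (pred - target) ** 2))


# Hypothetical sparse label: 41 of 4096 pixels (about 1%) are 1.0, the rest 0.0
target = np.zeros((64, 64), dtype=np.float32)
target.flat[::100] = 1.0

all_zero = np.zeros_like(target)                # predicts "nothing"
constant = np.full_like(target, target.mean())  # a maximally blurry prediction

print(mse(all_zero, target))           # about 0.01 despite missing every signal pixel
print(mse(constant, target))           # slightly lower than the all-zero guess
print(weighted_mse(all_zero, target))  # ten times the plain MSE here
```

The same arithmetic can be written with PyTorch tensor operations to use it as a training criterion. If the labels are rescaled first, `threshold` has to be expressed in the rescaled units.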
