U-Net predictions look like the input, but not like the target

Hello, I am trying to train a U-NET model with Sea Surface Temperature data. The minimum and maximum values are -2 and 32. Input and target images are 32x32.

Data Loading

import torch
from torch.utils.data import DataLoader, random_split

# Load your data as PyTorch tensors and add a channel dimension: (N, 1, 32, 32)
X = torch.tensor(input_sst_array, dtype=torch.float32)
X = X.unsqueeze(1)
y = torch.tensor(target_sst_array, dtype=torch.float32)
y = y.unsqueeze(1)

# Normalize the data: (value - 15) / 17 maps the range [-2, 32] to [-1, 1]
X_normalized = (X - 15) / 17
y_normalized = (y - 15) / 17

print(X_normalized.shape, y_normalized.shape)

# Create the dataset (MyDataset is your own Dataset class returning (input, target) pairs)
dataset = MyDataset(X_normalized, y_normalized)

# Set a fixed random seed for reproducibility.
# torch.cuda.manual_seed_all() alone does not affect random_split, which draws from
# the CPU generator, so seed it explicitly and pass a generator to the split.
seed = 42
torch.manual_seed(seed)
generator = torch.Generator().manual_seed(seed)

# Split the dataset into training, validation, and test sets (80/10/10)
train_len = int(0.8 * len(dataset))
val_len = int(0.1 * len(dataset))
test_len = len(dataset) - train_len - val_len
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_len, val_len, test_len], generator=generator
)

# Hyperparameters
num_epochs = 100
patience = 15
config_no = 0
learning_rate = 0.001
batch_size = 256

# Create data loaders (drop_last only matters for training; keep every validation/test sample)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Input Shape: torch.Size([131158, 1, 32, 32]), Target Shape: torch.Size([131158, 1, 32, 32])
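The constants 15 and 17 in the normalization are the center and half-range of [-2, 32], so the network works in [-1, 1] units. Its predictions come out in those units too, and they need to be mapped back before being plotted or compared with the targets in °C. A minimal sketch, assuming the [-2, 32] range stated above (the helper names are illustrative):

```python
import torch

# Assumed physical range of the SST data, taken from the post: [-2, 32]
SST_MIN, SST_MAX = -2.0, 32.0
CENTER = (SST_MAX + SST_MIN) / 2      # 15.0
HALF_RANGE = (SST_MAX - SST_MIN) / 2  # 17.0


def normalize(t: torch.Tensor) -> torch.Tensor:
    """Map temperatures in [SST_MIN, SST_MAX] to [-1, 1]."""
    return (t - CENTER) / HALF_RANGE


def denormalize(t: torch.Tensor) -> torch.Tensor:
    """Map network outputs in normalized units back to temperatures."""
    return t * HALF_RANGE + CENTER


if __name__ == "__main__":
    sst = torch.tensor([-2.0, 15.0, 32.0])
    z = normalize(sst)
    print(z, denormalize(z))
```

Keeping both directions in one place avoids mismatched constants between the training and plotting code.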

Model Architecture

import torch
import torch.nn as nn


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        # Contracting path
        self.enc_conv1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),  # Batch normalization added
            nn.ReLU(inplace=True),  # Non-linearity after each Conv + BN
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),  # Batch normalization added
            nn.ReLU(inplace=True),
        )
        self.pool1 = nn.MaxPool2d(2)

        self.enc_conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),  # Batch normalization added
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),  # Batch normalization added
            nn.ReLU(inplace=True),
        )
        self.pool2 = nn.MaxPool2d(2)

        self.enc_conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),  # Batch normalization added
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),  # Batch normalization added
            nn.ReLU(inplace=True),
        )
        self.pool3 = nn.MaxPool2d(2)

        # Expanding path
        self.up1 = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)
        self.dec_conv1 = nn.Sequential(
            nn.Conv2d(512, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),  # Batch normalization added
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),  # Batch normalization added
            nn.ReLU(inplace=True),
        )

        self.up2 = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)
        self.dec_conv2 = nn.Sequential(
            nn.Conv2d(256, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),  # Batch normalization added
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),  # Batch normalization added
            nn.ReLU(inplace=True),
        )

        self.up3 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        self.dec_conv3 = nn.Sequential(
            nn.Conv2d(128, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),  # Batch normalization added
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),  # Batch normalization added
            nn.ReLU(inplace=True),
        )

        # Output layer
        self.output = nn.Conv2d(32, 1, kernel_size=1)

    def forward(self, x):
        # Contracting path
        c1 = self.enc_conv1(x)       # (N, 64, 32, 32)
        p1 = self.pool1(c1)

        c2 = self.enc_conv2(p1)      # (N, 128, 16, 16)
        p2 = self.pool2(c2)

        c3 = self.enc_conv3(p2)      # (N, 256, 8, 8)
        p3 = self.pool3(c3)          # (N, 256, 4, 4)

        # Expanding path
        u1 = self.up1(p3)
        cat1 = torch.cat((u1, c3), dim=1)
        dc1 = self.dec_conv1(cat1)   # (N, 128, 8, 8)

        u2 = self.up2(dc1)
        cat2 = torch.cat((u2, c2), dim=1)
        dc2 = self.dec_conv2(cat2)   # (N, 64, 16, 16)

        u3 = self.up3(dc2)
        cat3 = torch.cat((u3, c1), dim=1)
        dc3 = self.dec_conv3(cat3)   # (N, 32, 32, 32)

        # Output layer
        out = self.output(dc3)       # (N, 1, 32, 32)

        return out

# Use Apple's Metal (MPS) backend if available, otherwise fall back to the CPU
mps_device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Create the model and print its summary
model_unet = UNet().to(mps_device)
_ = summary(model_unet, (1, 32, 32))
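Every encoder and decoder block above repeats the same Conv-BN pattern, which makes it easy to miss a layer in one of them. A sketch that factors the pattern into a helper, assuming the usual Conv-BN-ReLU ordering and the same channel widths (64/128/256), and checks that a 32x32 input comes back as 32x32. `double_conv` and `CompactUNet` are illustrative names, not part of the original code:

```python
import torch
import torch.nn as nn


def double_conv(in_ch: int, out_ch: int) -> nn.Sequential:
    """Two (3x3 Conv -> BatchNorm -> ReLU) stages; padding=1 keeps H and W."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


class CompactUNet(nn.Module):
    """Three-level U-Net with the same widths as the UNet class in the post."""

    def __init__(self, in_channels: int = 1, out_channels: int = 1):
        super().__init__()
        self.enc1 = double_conv(in_channels, 64)
        self.enc2 = double_conv(64, 128)
        self.enc3 = double_conv(128, 256)
        self.pool = nn.MaxPool2d(2)  # parameter-free, so one instance can be reused
        self.up1 = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)
        self.dec1 = double_conv(512, 128)
        self.up2 = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)
        self.dec2 = double_conv(256, 64)
        self.up3 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        self.dec3 = double_conv(128, 32)
        self.output = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        c1 = self.enc1(x)                                     # (N, 64, 32, 32)
        c2 = self.enc2(self.pool(c1))                         # (N, 128, 16, 16)
        c3 = self.enc3(self.pool(c2))                         # (N, 256, 8, 8)
        p3 = self.pool(c3)                                    # (N, 256, 4, 4)
        d1 = self.dec1(torch.cat((self.up1(p3), c3), dim=1))  # (N, 128, 8, 8)
        d2 = self.dec2(torch.cat((self.up2(d1), c2), dim=1))  # (N, 64, 16, 16)
        d3 = self.dec3(torch.cat((self.up3(d2), c1), dim=1))  # (N, 32, 32, 32)
        return self.output(d3)                                # (N, 1, 32, 32)


if __name__ == "__main__":
    model = CompactUNet()
    print(model(torch.randn(4, 1, 32, 32)).shape)
```

The `in_channels` argument also makes it straightforward to feed extra input fields later.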

Loss Functions

import torch.nn.functional as F
from pytorch_msssim import SSIM  # assumed source of SSIM (the constructor arguments match)

alpha = 0.5  # Weight for the SSIM term
beta = 0.5   # Weight for the L1 / L2 term

# Define l1 loss
def mae_loss(pred, target):
    return F.l1_loss(pred, target)

# Define l2 loss
def mse_loss(pred, target):
    return F.mse_loss(pred, target)

# The normalized data lies in [-1, 1], so its data range is 2.0, not 1.0
ssim_loss = SSIM(data_range=2.0, size_average=True, channel=1)  # Adjust the channel according to your input

# Map each configuration name to (description, loss function)
loss_functions = {
    'L1': ('L1', lambda scores, targets: mae_loss(scores, targets)),
    'L2': ('L2', lambda scores, targets: mse_loss(scores, targets)),
    'SSIM': ('SSIM', lambda scores, targets: 1 - ssim_loss(scores, targets)),
    'SSIM+L1': ('SSIM+L1', lambda scores, targets: alpha * (1 - ssim_loss(scores, targets)) + beta * mae_loss(scores, targets)),
    'SSIM+L2': ('SSIM+L2', lambda scores, targets: alpha * (1 - ssim_loss(scores, targets)) + beta * mse_loss(scores, targets)),
}

# Generate one configuration per loss function
experiment_configs = []
for loss_name, (description, loss_func) in loss_functions.items():
    config = {
        'loss_func': loss_func,
        'loss_name': description,  # Include the loss description for easier tracking
    }
    experiment_configs.append(config)

Training Loop

import copy
import os

import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# Define the directory path
model_dir = 'denoising_v1'

# Create the directory if it doesn't exist
os.makedirs(model_dir, exist_ok=True)

for config in experiment_configs:
    print("config_no:", config_no)
    model_unet = UNet().to(mps_device)
    optimizer = optim.Adam(model_unet.parameters(), lr=learning_rate)
    loss_function = config['loss_func']
    scheduler = StepLR(optimizer, step_size=15, gamma=0.1)
    best_val_loss = float('inf')
    best_train_loss = float('inf')
    best_epoch = -1
    best_model_state = None
    early_stopping_counter = 0
    epoch_train_losses = []
    epoch_val_losses = []
    for epoch in range(num_epochs):
        model_unet.train()  # BatchNorm uses batch statistics while training
        running_loss = 0.0
        for batch_idx, (data, targets) in enumerate(train_loader):
            data, targets = data.to(mps_device), targets.to(mps_device)
            scores = model_unet(data)
            loss = loss_function(scores, targets)  # compute the loss once per batch

            # Backpropagate and update the weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        scheduler.step()  # multiply the learning rate by 0.1 every 15 epochs
        epoch_loss = running_loss / len(train_loader)

        model_unet.eval()  # BatchNorm uses its running statistics during evaluation
        val_loss = 0.0
        with torch.no_grad():
            for data, targets in val_loader:
                data, targets = data.to(mps_device), targets.to(mps_device)
                scores = model_unet(data)
                val_loss += loss_function(scores, targets).item()
        val_loss /= len(val_loader)
        epoch_train_losses.append(epoch_loss)
        epoch_val_losses.append(val_loss)
        print(f"Epoch: {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.10f}, Validation Loss: {val_loss:.10f}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_train_loss = epoch_loss
            best_epoch = epoch
            # state_dict() holds references to the live weights, so copy it
            best_model_state = copy.deepcopy(model_unet.state_dict())
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1
        if early_stopping_counter >= patience:
            print("Early stopping triggered for configuration:", config['loss_name'])
            break
    print(f"Config_no: {config_no}, Loss Function={config['loss_name']}, Best Training Loss={best_train_loss:.4f}, Best Validation Loss={best_val_loss:.4f} at Epoch {best_epoch}")
    config_no += 1
    # Construct the plot file path
    plot_file_path = os.path.join(model_dir, f"unet_kaggle_Plot_Loss_unet_config_{config['loss_name']}.png")

    plt.figure(figsize=(10, 4))
    plt.plot(epoch_train_losses, label="Training Loss")
    plt.plot(epoch_val_losses, label="Validation Loss")
    plt.axvline(x=best_epoch, color='r', linestyle='--', label="Best Epoch")
    plt.title(f"Tr and Val Loss (Loss: {config['loss_name']}, Tr Loss {best_train_loss:.4f} and Val Loss {best_val_loss:.4f} at Epoch {best_epoch})")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(plot_file_path)
    plt.close()

    loss_file_path = os.path.join(model_dir, f"Loss_{model_dir}_{config['loss_name']}_best_epoch_{best_epoch}.txt")
    with open(loss_file_path, 'w') as file:
        file.write('Epoch,Training Loss,Validation Loss\n')
        for epoch in range(len(epoch_train_losses)):
            file.write(f"{epoch},{epoch_train_losses[epoch]},{epoch_val_losses[epoch]}\n")

    model_path = os.path.join(model_dir, f"{model_dir}_{config['loss_name']}.pth")
    torch.save(best_model_state, model_path)
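The essential pieces of the loop (zero the gradients, backpropagate, step the optimizer, switch between train and eval mode, and copy the best weights rather than keep a reference to them) can be checked in isolation. A minimal sketch on a hypothetical toy task, using a single convolution instead of the U-Net; `train_with_early_stopping` is an illustrative name:

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def train_with_early_stopping(model, train_loader, val_loader, loss_fn,
                              num_epochs=20, patience=3, lr=1e-3):
    """Return (best_state, best_epoch, train_losses, val_losses)."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_val, best_epoch, best_state, bad_epochs = float("inf"), -1, None, 0
    train_losses, val_losses = [], []
    for epoch in range(num_epochs):
        model.train()
        running = 0.0
        for x, y in train_loader:
            loss = loss_fn(model(x), y)
            optimizer.zero_grad()  # clear gradients from the previous step
            loss.backward()        # backpropagate
            optimizer.step()       # update the weights
            running += loss.item()
        train_losses.append(running / len(train_loader))

        model.eval()
        with torch.no_grad():
            val = sum(loss_fn(model(x), y).item() for x, y in val_loader) / len(val_loader)
        val_losses.append(val)

        if val < best_val:
            best_val, best_epoch, bad_epochs = val, epoch, 0
            # deepcopy: state_dict() returns references to the live parameters
            best_state = copy.deepcopy(model.state_dict())
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break
    return best_state, best_epoch, train_losses, val_losses


if __name__ == "__main__":
    torch.manual_seed(0)
    # Hypothetical toy task: the target is a fixed affine function of the input
    x = torch.randn(256, 1, 8, 8)
    ds = TensorDataset(x, 2.0 * x + 1.0)
    model = nn.Conv2d(1, 1, kernel_size=3, padding=1)
    state, best_epoch, tr, va = train_with_early_stopping(
        model, DataLoader(ds, batch_size=32, shuffle=True), DataLoader(ds, batch_size=64), F.mse_loss
    )
    print(f"best epoch {best_epoch}, train loss {tr[0]:.4f} -> {tr[-1]:.4f}")
```

If the training loss in this kind of loop does not fall, the first thing to check is whether `backward()` and `step()` are actually being called.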

Here, I have attached loss plots for L1 and L2. Additionally, I have included two sets of predictions from two models trained with different Loss Functions L1 and L2.

A couple of questions to build some intuition:

  1. Why calculate the loss twice in the loop?

  2. What is your objective? Is it segmentation? If so, how many classes exactly?

Hello, thank you for your reply.

  1. Repeating the loss calculation is a mistake. However, I don’t think it changes anything?
  2. This is not a segmentation task. It is a kind of regression problem where my intention is to convolve the input image.

You’re trying to do image2image regression? Not sure what your target is here.

Yes, image2image regression.

Hi Soumya, if you think U-NET is not well suited to this kind of image2image regression, would you please suggest a few models that I might try?

Unfortunately, I have no experience with this. Though this may be worthwhile?

Regarding your training: it’s quite confusing why it’s actually converging. Everything looks okay to me except the config_no update, and I’m not sure how much of a difference that makes. I think the formulation of your loss may be faulty? Again, not sure.

Let’s also talk about why you are getting these results, and why they are most likely due to MSE and MAE.

MSE and MAE average the error over every pixel, so when the input does not pin down the target, they reward the most generalized prediction. What you have essentially done with this model is tell it to be more like the labels, but in the way in which the features across those labeled images are most generalized.
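One way to make this "generalization" concrete: if the same input could plausibly lead to several different target values at a pixel, the single prediction that minimizes MSE is their mean, and the one that minimizes MAE is their median. A toy illustration (not from the thread; the sample values are hypothetical) that finds the best constant prediction by a grid search:

```python
import torch


def best_constant(samples: torch.Tensor, loss: str) -> float:
    """Return the constant prediction c that minimizes the average loss over samples."""
    candidates = torch.linspace(samples.min().item() - 1, samples.max().item() + 1, 1201)
    err = candidates[:, None] - samples[None, :]  # (num_candidates, num_samples)
    if loss == "mse":
        avg = (err ** 2).mean(dim=1)
    elif loss == "mae":
        avg = err.abs().mean(dim=1)
    else:
        raise ValueError(f"unknown loss: {loss}")
    return candidates[avg.argmin()].item()


# Hypothetical plausible targets for one pixel: usually cold (0.0), occasionally warm (10.0)
samples = torch.tensor([0.0, 0.0, 0.0, 10.0])
print(best_constant(samples, "mse"))  # near the mean of the samples
print(best_constant(samples, "mae"))  # near the median of the samples
```

Neither answer reproduces the rare warm value, which is one way a per-pixel loss can wash out features that the input does not clearly predict.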

This is why for all of the predictions you can see that the dark blue circles are more like the labels than the inputs. This is most likely due to the fact that in your labeled image data the dark blue spots all have the same pattern (a dark blue center that gets lighter as you move away from the center).

So, you do not have a UNet training issue; you have an issue that has to do with the consistency of your labeled data. The model is generalizing the input to be most like what the output is, and because there isn’t anything in the input to give it an idea of how to accurately change the yellows and reds into a consistent pattern, it falls back on the most generalized pattern.

Hope this helps, feel free to ask questions.

Also, if you have other data about the Sea Surface Temperature data outside of these images, I highly recommend trying to find a way to include it.

That data is most likely very valuable in determining these turbulent features.
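If the other variables are available on the same 32x32 grid, one common way to include them is to stack them as extra input channels and widen the first convolution to match. A minimal sketch, assuming each field is an already-normalized tensor of shape (N, H, W); the field names are hypothetical:

```python
import torch
import torch.nn as nn


def stack_inputs(sst: torch.Tensor, *extra_fields: torch.Tensor) -> torch.Tensor:
    """Stack SST and any auxiliary fields of shape (N, H, W) into (N, C, H, W)."""
    fields = (sst, *extra_fields)
    for f in fields:
        if f.shape != sst.shape:
            raise ValueError(f"field shape {tuple(f.shape)} does not match SST shape {tuple(sst.shape)}")
    return torch.stack(fields, dim=1)


if __name__ == "__main__":
    n = 8
    sst = torch.randn(n, 32, 32)
    # Hypothetical auxiliary fields (e.g. wind speed, sea surface height), already normalized
    wind = torch.randn(n, 32, 32)
    ssh = torch.randn(n, 32, 32)
    x = stack_inputs(sst, wind, ssh)
    # The first convolution must accept one channel per field
    first_conv = nn.Conv2d(x.shape[1], 64, kernel_size=3, padding=1)
    print(x.shape, first_conv(x).shape)
```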

Thank you Soumya. I am not sure whether I made a mistake in the loss calculation. Also, I have tried every suggestion I found on the internet.

Hello Mycul, thank you for your detailed reply. I am confused about the labeling you mentioned above. Do you have any suggestions I might try at any step (data preprocessing, changes to the model architecture, or anything in the training)?

Thanks again!

I would not agree with this:

This is most likely due to the fact that in your labeled image data the dark blue spots all have the same pattern (a dark blue center that gets lighter as you move away from the center).

Just look at the third image. It’s more about the gradient from the edge.

The model is generalizing the input to be most like what the output is,

What do you mean by this? Also, both MAE and MSE measure the difference between the output and the target (the general idea of a loss), so how could the model generalize toward the input because of that?

What I mean is that the labeled data does not contain a consistent enough pattern to be accurately mapped from the input data.

I think the third image also illustrates this point (it is a bit harder to describe, but I’ll try my best). Essentially, in the labeled outputs the dynamic patterns that describe this data are more consistent in the blues than in the yellows and reds (the blues cover less dynamic, more generalized patterns).

This also makes sense from a physics-based perspective when analyzing sea surface temperatures. High-temperature areas are higher-energy areas and thus have more turbulent behavior (which creates dynamic patterns), while low-temperature areas, with less movement of particles, are lower-energy and less turbulent (which creates simpler patterns).

Once again… this is how I am seeing the issue. If you see it in another way, I would love to hear your analysis.

Well, you can’t really relabel the data because it’s just higher-grade weather data from the looks of it.

I would suggest another model, adding in other aspects of the data if you have them, or maybe doing some type of multi-task learning to teach the model that this is more than just an image: it represents physics that are actually happening.

Maybe some type of keypoint detection, where you indicate the ‘coldest’ and ‘hottest’ areas within a certain threshold and have your models predict those areas as well as the output image, might be beneficial? This wouldn’t really give the model an idea of the way everything flows, though, so I don’t know.
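Keypoint targets of this kind could be derived directly from the target maps. A sketch, assuming (N, H, W) tensors, that returns the row and column of the hottest and coldest pixel in each map; a thresholded region (for example, pixels above a per-map quantile) would be a variation on the same idea. `extreme_points` is a hypothetical helper:

```python
import torch


def extreme_points(sst: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (row, col) of the hottest and coldest pixel for each map.

    sst: (N, H, W) tensor. Returns two (N, 2) integer tensors.
    """
    n, h, w = sst.shape
    flat = sst.reshape(n, -1)
    hot = flat.argmax(dim=1)   # flat index of the maximum per map
    cold = flat.argmin(dim=1)  # flat index of the minimum per map
    hot_rc = torch.stack((torch.div(hot, w, rounding_mode="floor"), hot % w), dim=1)
    cold_rc = torch.stack((torch.div(cold, w, rounding_mode="floor"), cold % w), dim=1)
    return hot_rc, cold_rc


if __name__ == "__main__":
    maps = torch.randn(4, 32, 32)
    hot_rc, cold_rc = extreme_points(maps)
    print(hot_rc, cold_rc)
```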

A potential solution might be to relabel the data into certain thresholds and create masks of these thresholds for the model to predict alongside the image (penalizing the model with both of these objectives in mind).

So essentially, make a mask for one range of temperature values (say, from 90 to 100 degrees), and do this for every band across the range. Then have the model predict these masks.
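The band masks can be encoded as one class index per pixel, and the two objectives combined into one loss. A sketch, assuming the U-Net's output layer is widened to 1 + K channels (channel 0 is the temperature, channels 1..K are band logits); the band edges are hypothetical values in normalized units, and L1 is used as the regression term only because it is one of the losses already tried in the thread:

```python
import torch
import torch.nn.functional as F


def temperature_bands(sst: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
    """Band index per pixel: 0 below edges[0], i for edges[i-1] <= v < edges[i], len(edges) above."""
    return torch.bucketize(sst, edges, right=True)


def multitask_loss(pred: torch.Tensor, target: torch.Tensor, edges: torch.Tensor,
                   band_weight: float = 0.5) -> torch.Tensor:
    """pred: (N, 1 + K, H, W); target: (N, 1, H, W) temperatures; K = len(edges) + 1."""
    temp_pred, band_logits = pred[:, :1], pred[:, 1:]
    if band_logits.shape[1] != edges.numel() + 1:
        raise ValueError("expected one logit channel per temperature band")
    band_target = temperature_bands(target[:, 0], edges)  # (N, H, W) int64 class indices
    regression = F.l1_loss(temp_pred, target)
    segmentation = F.cross_entropy(band_logits, band_target)
    return regression + band_weight * segmentation


if __name__ == "__main__":
    # Hypothetical band edges in normalized units: 4 edges -> 5 bands
    edges = torch.tensor([-0.5, 0.0, 0.5, 0.8])
    target = torch.rand(2, 1, 32, 32) * 2 - 1
    pred = torch.randn(2, 1 + edges.numel() + 1, 32, 32, requires_grad=True)
    loss = multitask_loss(pred, target, edges)
    loss.backward()
    print(loss.item())
```

`band_weight` controls how strongly the mask objective is penalized relative to the image objective.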

This might give it a better idea of how the temperatures interact with each other based on their shapes and positions relative to other color values. The issue with your current approach is that you aren’t expressing the problem well enough to the model. It just sees these images as pixel colors; it doesn’t realize they are all connected, so you have to give it a pattern to understand this.