Custom loss function with learnable weight map

I have an autoencoder model that is trained with L1 loss. I have modified it so that the model now produces two outputs: the predicted image (10x10) and a weight map of the same size (10x10). The weight map is learnable (I used an nn.Parameter with the Adam optimizer) and is updated based on the input.

I then calculate the pixel-wise absolute error (also 10x10) with respect to a reference image and multiply it element-wise with the weight map to obtain a weighted pixel-level L1 loss. What is preventing the model from pushing the values of the weight map to all 0s? How is the model still learning to pay attention to areas in the image?
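To make the concern concrete, here is a minimal sketch (not taken from the original post) that checks the gradient of a weighted L1 loss with respect to a free weight map. The tensor names and the all-ones initialization are assumptions for illustration only. It suggests that, without any constraint or penalty, nothing in this loss pushes the weights up.

```python
import torch

torch.manual_seed(0)

# Stand-ins for the predicted and reference images (assumed 10x10).
pred = torch.rand(10, 10)
target = torch.rand(10, 10)

# Hypothetical free weight map, initialized to ones for illustration.
w = torch.ones(10, 10, requires_grad=True)

abs_err = (pred - target).abs()   # pixel-wise absolute error, shape (10, 10)
loss = (abs_err * w).mean()       # weighted L1 loss (scalar)
loss.backward()

# d(loss)/d(w) equals abs_err / N, which is never negative, so a plain
# gradient-descent step can only lower each weight or leave it unchanged.
grad_matches = torch.allclose(w.grad, abs_err / abs_err.numel())
grad_nonnegative = bool((w.grad >= 0).all())
```

If both flags are true, the only thing keeping the map away from zero is whatever constraint sits on top of it, for example a normalization or an extra penalty term.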

I don’t fully understand this explanation, since your model won’t be able to create parameters.
Are you rewrapping the model output into an nn.Parameter and trying to optimize this activation tensor?

No, and I’m not sure I did it correctly. I started with an nn.Parameter and passed it to the model along with my image. I forced the model to create two outputs: the predicted image and my weight map (produced by a conv+relu layer followed by a softmax). I added the nn.Parameter to my Adam optimizer along with the model parameters.
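As a rough sketch of the conv+relu+softmax idea described above, the block below applies a softmax over all pixel positions so that the weights stay positive and sum to a fixed value, which means they cannot all collapse to 0. The class name WeightHead, the single conv layer, and the rescaling by H*W are assumptions for illustration, not the poster's actual code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightHead(nn.Module):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)

    def forward(self, x):
        # x has shape (N, 1, H, W)
        n, c, h, w = x.shape
        scores = F.relu(self.conv(x))
        # Softmax over all pixel positions of each sample: the weights are
        # strictly positive and sum to 1 per sample.
        flat = F.softmax(scores.reshape(n, -1), dim=1)
        # Rescale so the mean weight is 1 (an assumed convention that keeps
        # the loss on the same scale as an unweighted L1 loss).
        return (flat * (h * w)).reshape(n, c, h, w)


head = WeightHead()
wm = head(torch.rand(2, 1, 10, 10))
```

With a normalization like this, the model can only redistribute weight between pixels; it cannot shrink the total.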

Could you post a minimal and executable code snippet showing this issue, please?

Okay, here you go. I have a simple autoencoder and a modified loss function.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
class ConvAutoencoder(nn.Module):
    def __init__(self):
        super(ConvAutoencoder, self).__init__()

        # Encoder
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 4, 3, padding=1)
        self.pool = nn.MaxPool2d(2)

        # Decoder
        self.t_conv1 = nn.ConvTranspose2d(4, 16, 2, stride=2)
        self.t_conv2 = nn.ConvTranspose2d(16, 1, 2, stride=2)

        # Weight mapper
        self.convw = nn.Conv2d(1, 16, 3, padding='same')
        self.conwm2 = nn.Conv2d(16, 1, 3, padding='same')

    def forward(self, x, wm):
        x = F.relu(self.conv1(x), inplace=True)
        x = self.pool(x)

        x = F.relu(self.conv2(x), inplace=True)
        x = self.pool(x)

        x = F.relu(self.t_conv1(x), inplace=True)

        x = self.t_conv2(x)
        wm = self.convw(wm)
        wm = F.relu(self.conwm2(wm), inplace=True)

        return x, wm
convo=ConvAutoencoder()
optimizer = torch.optim.Adam(convo.parameters(),
                             lr = 1e-3,
                             weight_decay = 1e-8)
class weightedloss(nn.Module):
    def __init__(self):
        super(weightedloss, self).__init__()
        self.loss_fn = nn.L1Loss(reduction='none')

    def forward(self, input, targets, wmat):
        ls1 = self.loss_fn(input, targets)
        full = torch.mean(ls1 * wmat)

        return full

epochs = 20
outputs = []
weights = []
losses = []
bmat = np.random.rand(32, 32).astype('float32')
wmat = np.random.rand(32, 32).astype('float32')
wmat = nn.Parameter(torch.tensor(wmat))
convo = ConvAutoencoder()
losser = weightedloss()
optimizer = torch.optim.Adam(convo.parameters(),
                             lr=1e-3,
                             weight_decay=1e-8)
bmat = torch.from_numpy(bmat).unsqueeze(0)
wmat = wmat.unsqueeze(0)
for epoch in range(epochs):
    for i in range(0, 5):
        amat = np.random.rand(32, 32).astype('float32')
        amat = torch.from_numpy(amat).unsqueeze(0)

        predict, weight = convo(amat, wmat)
        loss = losser(predict, bmat, weight)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        weights.append(torch.mean(weight).item())

Here, the list named weights collects the mean of the weight map at each mini-batch.

Thanks again

Thanks for the code!

I don’t see where wmat is updated as it seems to be randomly initialized and then used in the weightedloss method.
Did you forget to pass it to an optimizer?
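For reference, a minimal sketch of how an extra nn.Parameter could be handed to the optimizer next to the model parameters. The single conv layer is only a stand-in for the autoencoder, and the shapes are assumptions. Note that the parameter is created with its final shape: the result of calling unsqueeze on a parameter is a new non-leaf tensor, so it is the original nn.Parameter object that should be passed to the optimizer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the autoencoder (assumption, to keep the sketch short).
model = nn.Conv2d(1, 1, 3, padding=1)

# Create the learnable weight map directly in its final shape so that the
# leaf tensor itself is what the optimizer sees.
wmat = nn.Parameter(torch.rand(1, 1, 32, 32))

# Pass the model parameters and the extra parameter together.
optimizer = torch.optim.Adam(list(model.parameters()) + [wmat], lr=1e-3)

x = torch.rand(1, 1, 32, 32)
target = torch.rand(1, 1, 32, 32)
before = wmat.detach().clone()

loss = ((model(x) - target).abs() * wmat).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# After one step the weight map should differ from its initial value.
wmat_changed = not torch.equal(before, wmat.detach())
```

If wmat_changed is False in a setup like this, the parameter is most likely not registered with the optimizer or is not part of the graph that produces the loss.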

So I am torn between two approaches:

  1. I use random initialization for wmat but learn the weights of the conv layers (the #weightmapper part), or
  2. I treat wmat as a parameter and update it directly with every backward pass.

I am not sure which works better. Also, the loss function is slightly updated now to full=torch.mean(ls1*wmat)+(1/wmat) to prevent wmat from going to 0. Are there any other ways to prevent wmat from becoming 0?
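One possible alternative, sketched under assumptions: keep the weights positive with a softplus and add a log-barrier penalty that grows as a weight approaches 0. This swaps the 1/wmat term for a -log(w) term and reduces everything to a scalar with a mean; the function name and the lam and eps values are made up for illustration and would need tuning.

```python
import torch
import torch.nn.functional as F


def weighted_l1_with_penalty(pred, target, raw_w, lam=0.1, eps=1e-6):
    """Weighted L1 loss plus a log-barrier penalty (illustrative sketch).

    raw_w is an unconstrained tensor; softplus maps it to positive weights.
    """
    w = F.softplus(raw_w) + eps
    abs_err = (pred - target).abs()
    data_term = (abs_err * w).mean()
    # The penalty grows without bound as any weight approaches 0.
    penalty = -lam * torch.log(w).mean()
    return data_term + penalty


torch.manual_seed(0)
pred = torch.rand(1, 32, 32)
target = torch.rand(1, 32, 32)
raw_w = torch.zeros(1, 32, 32, requires_grad=True)

loss = weighted_l1_with_penalty(pred, target, raw_w)
loss.backward()
```

One caveat on the design: per pixel, this objective is minimized at w = lam / abs_err, so the weights stay away from 0 but end up larger where the error is small. Whether that matches the intended notion of attention depends on the application. Normalizing the map, for example with a softmax over pixels, is another way to rule out the all-zero solution.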