Multiple loss functions using nn.Module

I am working on an image restoration task and I am considering multiple loss functions. My plan is to consider three routes:

1: Use several losses for monitoring, but only a subset of them for the training itself.
2: Give each loss that is used for training a weight. Currently I specify the weights by hand; I would like to make them adaptive.
3: If I observe saturation during training, change the loss function or its components. Currently I am considering re-training the network (if the model saturated in the first run) so that it trains with a particular loss function for the first, say, M epochs, after which I change the loss. A rough sketch of what I mean is given right after this list.
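
For route 3, this is roughly what I have in mind (a sketch only; M, num_epochs, model, loader and optimizer are placeholders defined elsewhere):

import torch.nn as nn

M = 50  # hypothetical epoch at which the loss is switched
mse = nn.MSELoss(reduction='sum')
mae = nn.L1Loss(reduction='sum')

for epoch in range(num_epochs):
    criterion = mse if epoch < M else mae  # switch the loss after M epochs
    for input, target in loader:
        optimizer.zero_grad()
        loss = criterion(model(input), target)
        loss.backward()
        optimizer.step()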

  1. Except for the last case, I have written code that computes these losses (given below), but I am not sure whether it will work, i.e. whether it will backpropagate.

  2. Is it possible to set the weights adaptively when using a combination of loss functions, i.e. can the network be trained so that these weights are also learned?

  3. Can this implementation be used for the above-mentioned case 3 of changing loss functions?

Sorry if anything given here is unclear or wrong; I am fairly new to PyTorch.

import torch
import torch.nn as nn

criterion = _criterion()  # instantiate the module; assigning the class itself would fail
# --- training ---
prediction = model(input)
loss, loss_1, loss_2, loss_3, loss_4 = criterion(prediction, target)  # forward returns five values
loss.backward()



class _criterion(nn.Module):
    def __init__(self, model_type="CNN"):
        super(_criterion, self).__init__()  # was super(_criterion).__init__(), which never runs nn.Module.__init__
        self.model_type = model_type

    def forward(self, pred, ref):
        # Losses computed for monitoring; reduction='sum' replaces the deprecated size_average=False.
        loss_1 = lambda x, y: nn.MSELoss(reduction='sum')(x, y)
        loss_2 = lambda x, y: nn.L1Loss(reduction='sum')(x, y)
        loss_3 = lambda x, y: nn.SmoothL1Loss(reduction='sum')(x, y)
        loss_4 = lambda x, y: L1_Charbonnier_loss_()(x, y)  # user-defined

        # opt is assumed to be a global options object (e.g. from argparse).
        if opt.loss_function_order == 1:
            loss_function_1 = get_loss_function(opt.loss_function_1)
            loss = lambda x, y: 1 * loss_function_1(x, y)

        elif opt.loss_function_order == 2:
            loss_function_1 = get_loss_function(opt.loss_function_1)
            loss_function_2 = get_loss_function(opt.loss_function_2)
            weight_1 = opt.loss_function_1_weight
            weight_2 = opt.loss_function_2_weight
            loss = lambda x, y: weight_1 * loss_function_1(x, y) + weight_2 * loss_function_2(x, y)

        elif opt.loss_function_order == 3:
            loss_function_1 = get_loss_function(opt.loss_function_1)
            loss_function_2 = get_loss_function(opt.loss_function_2)
            loss_function_3 = get_loss_function(opt.loss_function_3)
            weight_1 = opt.loss_function_1_weight
            weight_2 = opt.loss_function_2_weight
            weight_3 = opt.loss_function_3_weight
            loss = lambda x, y: weight_1 * loss_function_1(x, y) + weight_2 * loss_function_2(x, y) + weight_3 * loss_function_3(x, y)

        elif opt.loss_function_order == 4:
            loss_function_1 = get_loss_function(opt.loss_function_1)
            loss_function_2 = get_loss_function(opt.loss_function_2)
            loss_function_3 = get_loss_function(opt.loss_function_3)
            loss_function_4 = get_loss_function(opt.loss_function_4)
            weight_1 = opt.loss_function_1_weight
            weight_2 = opt.loss_function_2_weight
            weight_3 = opt.loss_function_3_weight
            weight_4 = opt.loss_function_4_weight
            loss = lambda x, y: weight_1 * loss_function_1(x, y) + weight_2 * loss_function_2(x, y) + weight_3 * loss_function_3(x, y) + weight_4 * loss_function_4(x, y)

        else:
            raise Exception("_criterion : unable to interpret loss_function_order")

        # Argument order fixed to (pred, ref) for consistency with the signature;
        # these particular losses are symmetric, so the value is unchanged.
        return loss(pred, ref), loss_1(pred, ref), loss_2(pred, ref), loss_3(pred, ref), loss_4(pred, ref)



def get_loss_function(loss):
    if loss == "MSE":
        criterion = nn.MSELoss(reduction='sum')
    elif loss == "MAE":
        criterion = nn.L1Loss(reduction='sum')
    elif loss == "Smooth-L1":
        criterion = nn.SmoothL1Loss(reduction='sum')
    elif loss == "Charbonnier":
        criterion = L1_Charbonnier_loss_()
    else:
        raise Exception("not implemented")
    return criterion


class L1_Charbonnier_loss_(nn.Module):
    def __init__(self):
        super(L1_Charbonnier_loss_, self).__init__()
        self.eps = 1e-6

    def forward(self, X, Y):
        diff = X - Y
        # Smooth approximation of L1: eps * (sqrt(1 + diff^2 / eps) - 1)
        error = self.eps * (torch.sqrt(1 + (diff * diff) / self.eps) - 1)
        loss = torch.sum(error)
        return loss

  1. You can check whether your loss function implementations are correct by verifying that all expected parameters receive a valid .grad value after the loss.backward() call. Initially each param.grad attribute is None; after the backward pass it should contain a gradient tensor.
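
A minimal sketch of that check (run it right after loss.backward(); model is your network):

for name, param in model.named_parameters():
    if param.grad is None:
        print(name, "received no gradient!")
    else:
        print(name, "grad norm:", param.grad.norm().item())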

  2. I’m afraid that your training routine might “learn” to achieve the best loss by setting all weights to zero or negative values. In the end, SGD is trying to lower the loss.
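
If you still want to learn the weights, one common workaround (a sketch, not taken from the code above) is to constrain them, e.g. pass raw parameters through a softmax so they stay positive and sum to one, and register them as nn.Parameter so the optimizer updates them:

import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightedLoss(nn.Module):
    # Sketch: learnable combination weights, constrained via softmax so they
    # stay positive and sum to one, which blocks the all-zero/negative solution.
    def __init__(self, losses):
        super().__init__()
        self.losses = nn.ModuleList(losses)
        self.raw_weights = nn.Parameter(torch.zeros(len(losses)))

    def forward(self, pred, ref):
        weights = F.softmax(self.raw_weights, dim=0)
        return sum(w * l(pred, ref) for w, l in zip(weights, self.losses))

# Usage sketch: give the criterion's parameters to the optimizer as well.
# criterion = WeightedLoss([nn.MSELoss(), nn.L1Loss()])
# optimizer = torch.optim.Adam(list(model.parameters()) + list(criterion.parameters()))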
