I am working on an image restoration task and am considering multiple loss functions. My plan covers three cases:
1. Compute several losses for monitoring, but train on only a subset of them.
2. Each loss function used for training needs a weight. Currently I specify these weights by hand; I would like to make them adaptive.
3. If I observe that training saturates partway through, I would like to change the loss function or its components. My current approach is to retrain the network (when the first run saturated) so that it trains with one loss function for the first, say, M epochs and then switches to a different loss.

For all but the last case I wrote code that computes these losses (given below), but I am not sure whether it works, i.e. whether gradients will backpropagate through it.

Is it possible to set the weights adaptively when using a combination of loss functions, i.e. can the network be trained so that these weights are learned as well?

Can this implementation be used for case 3 above, i.e. changing the loss function during training?
Sorry if anything here is unclear or wrong; I am fairly new to PyTorch.
# Training step (illustrative snippet). `_criterion` is defined further down
# in this file, so this code has to run after that class definition.
# `model`, `input` and `target` come from the surrounding training loop,
# which is not shown here. Presumably optimizer.zero_grad() / optimizer.step()
# wrap these lines in the real loop -- TODO confirm.
criterion = _criterion()  # instantiate: calling the bare class would invoke its constructor
# training
prediction = model(input)
# _criterion.forward returns a tuple whose first element is the training
# loss; the remaining elements are monitoring losses. Only the training loss
# is back-propagated -- a tuple has no .backward().
loss, *monitor_losses = criterion(prediction, target)
loss.backward()
class _criterion(nn.Module):
    """Weighted combination of reconstruction losses plus monitoring losses.

    The training loss is configured through the global ``opt`` namespace and
    is re-read on every forward pass:

    * ``opt.loss_function_order`` -- number of terms to combine (1 to 4);
    * ``opt.loss_function_<k>`` -- name of the k-th term, resolved with
      ``get_loss_function``;
    * ``opt.loss_function_<k>_weight`` -- weight of the k-th term. It is
      ignored when only one term is used; that term gets weight 1.

    Because ``opt`` is consulted on each call, changing these fields between
    epochs switches the training loss without rebuilding this module.
    """

    def __init__(self, model_type="CNN"):
        # nn.Module.__init__ must actually run; the former
        # ``super(_criterion).__init__()`` built an unbound super object and
        # never initialised the module's internal state.
        super().__init__()
        self.model_type = model_type

    def forward(self, pred, ref):
        """Return ``(training_loss, mse, l1, smooth_l1, charbonnier)``.

        Every loss is summed over all elements. Only ``training_loss`` is
        attached to the autograd graph. The four monitoring losses are
        computed under ``torch.no_grad()`` because they are for logging only,
        so no graph is built for them.

        Raises:
            ValueError: if ``opt.loss_function_order`` is not 1, 2, 3 or 4.
        """
        training_loss = self._training_loss(ref, pred)
        with torch.no_grad():
            monitors = (
                nn.MSELoss(reduction="sum")(ref, pred),
                nn.L1Loss(reduction="sum")(ref, pred),
                nn.SmoothL1Loss(reduction="sum")(ref, pred),
                L1_Charbonnier_loss_()(ref, pred),
            )
        return (training_loss, *monitors)

    def _training_loss(self, ref, pred):
        """Weighted sum of the loss terms selected by ``opt`` (differentiable)."""
        order = opt.loss_function_order
        if order not in (1, 2, 3, 4):
            raise ValueError("_criterion : unable to interpret loss_function_order")
        order = int(order)
        total = 0
        for k in range(1, order + 1):
            term = get_loss_function(getattr(opt, f"loss_function_{k}"))
            # A single term is used unweighted, matching the original order==1 branch.
            weight = 1 if order == 1 else getattr(opt, f"loss_function_{k}_weight")
            total = total + weight * term(ref, pred)
        return total
def get_loss_function(loss):
    """Return a new loss module for the loss name ``loss``.

    Supported names are "MSE", "MAE" (L1), "SmoothL1" and "Charbonnier".
    The built-in PyTorch losses use ``reduction="sum"``, the non-deprecated
    equivalent of ``size_average=False``. They therefore return the error
    summed over all elements rather than the mean.

    Raises:
        NotImplementedError: if ``loss`` is not one of the supported names.
    """
    factories = {
        "MSE": lambda: nn.MSELoss(reduction="sum"),
        "MAE": lambda: nn.L1Loss(reduction="sum"),
        "SmoothL1": lambda: nn.SmoothL1Loss(reduction="sum"),
        # Deferred so the name is only resolved when this loss is requested.
        "Charbonnier": lambda: L1_Charbonnier_loss_(),
    }
    try:
        factory = factories[loss]
    except KeyError:
        raise NotImplementedError(f"not implemented: loss function {loss!r}") from None
    return factory()
class L1_Charbonnier_loss_(nn.Module):
    """Charbonnier (smooth L1) loss, summed over all elements.

    Computes ``sum(sqrt((X - Y) ** 2 + eps))``. This is a differentiable
    approximation of the L1 distance: each element behaves like ``|X - Y|``
    when ``|X - Y|`` is much larger than ``sqrt(eps)``, and the loss stays
    smooth at ``X == Y``.

    Args:
        eps: Small non-negative constant inside the square root. The default
            is ``1e-6``.
    """

    def __init__(self, eps=1e-6):
        super(L1_Charbonnier_loss_, self).__init__()
        # Must be small: a large eps (e.g. the mistyped 1e6) makes the loss
        # quadratic in the error instead of L1-like.
        self.eps = eps

    def forward(self, X, Y):
        # Element-wise difference (the pasted torch.add(X, Y) had lost its -Y).
        diff = X - Y
        error = torch.sqrt(diff * diff + self.eps)
        return torch.sum(error)