I am working on an image restoration task and am considering multiple loss functions. My plan is to explore three routes:
1: Compute several losses for monitoring, but use only a few of them for training itself (see the first sketch after this list).
2: Give each loss function used for training its own weight. Currently I specify the weights by hand, and I would like to make them adaptive.
3: If I observe saturation during training, I would like to change the loss function or its components. So far I have only considered re-training the network (when the first run saturated), training with one loss function for, say, the first M epochs and then switching the loss.
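For route 1, what I have in mind is roughly this sketch (assuming `model`, `input`, `target`, and an optimizer already exist in the training loop; the monitoring loss is computed under `torch.no_grad()` so only the training loss contributes gradients):

```python
import torch
import torch.nn as nn

mse = nn.MSELoss(reduction='sum')   # used for training
mae = nn.L1Loss(reduction='sum')    # used for monitoring only

prediction = model(input)
train_loss = mse(prediction, target)

with torch.no_grad():               # no graph is built here, so nothing backpropagates
    monitor_mae = mae(prediction, target)

train_loss.backward()               # gradients come from the MSE term alone
```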
-
Except for the last case, I have written code that computes these losses (given below), but I am not sure whether it will work, i.e. whether it will backpropagate.
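To check this myself, I tried a toy forward/backward pass with a stand-in network; my understanding is that if the parameter gradients are populated after `backward()`, the loss does backpropagate:

```python
import torch
import torch.nn as nn

model = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # dummy stand-in network
x = torch.randn(1, 3, 8, 8)
y = torch.randn(1, 3, 8, 8)

loss = nn.MSELoss(reduction='sum')(model(x), y)
loss.backward()

# if the loss backpropagates, the parameter gradients are populated
print(model.weight.grad is not None)  # expected: True
```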
-
Is it possible to set the weights adaptively when using a combination of loss functions, i.e. can the network be trained so that these weights are also learned?
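What I am imagining is registering the weights as `nn.Parameter` on the criterion and handing them to the optimizer along with the network parameters, roughly like this sketch (the class name and the softplus parametrization are my own guesses, not part of my code below; softplus is just one way to keep the effective weights positive):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveWeightedLoss(nn.Module):  # hypothetical name
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss(reduction='sum')
        self.mae = nn.L1Loss(reduction='sum')
        self.raw_w = nn.Parameter(torch.zeros(2))  # learned jointly with the network

    def forward(self, pred, ref):
        w = F.softplus(self.raw_w)  # keep the effective weights positive
        return w[0] * self.mse(pred, ref) + w[1] * self.mae(pred, ref)

criterion = AdaptiveWeightedLoss()
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(criterion.parameters()), lr=1e-4)
```

I realize an unconstrained learned weight might simply be driven toward zero (that trivially lowers the total loss), so presumably some constraint or regularization on the weights would be needed; I am mainly asking whether the mechanism itself is sound.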
-
Can this implementation be used for case 3 above, i.e. changing the loss function mid-training?
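What I was imagining for this, instead of re-training from scratch, is swapping the criterion inside the epoch loop, roughly like this (M is a placeholder; `num_epochs`, `loader`, `model`, and `optimizer` are assumed to exist):

```python
import torch.nn as nn

M = 50  # placeholder: epoch at which the loss changes

loss_a = nn.MSELoss(reduction='sum')
loss_b = nn.L1Loss(reduction='sum')

for epoch in range(num_epochs):
    criterion = loss_a if epoch < M else loss_b  # switch after M epochs
    for input, target in loader:
        optimizer.zero_grad()
        loss = criterion(model(input), target)
        loss.backward()
        optimizer.step()
```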
Sorry if anything here is unclear or wrong; I am fairly new to PyTorch.
import torch
import torch.nn as nn

criterion = _criterion()  # instantiate the module; bare `_criterion` would assign the class itself

#--training
prediction = model(input)
loss, loss_1, loss_2, loss_3, loss_4 = criterion(prediction, target)  # forward returns a tuple
loss.backward()  # backpropagate only the combined training loss
class _criterion(nn.Module):
    def __init__(self, model_type="CNN"):
        super().__init__()  # was super(_criterion).__init__(), which never ran nn.Module.__init__
        self.model_type = model_type

    def forward(self, pred, ref):
        # fixed losses, returned on every call for monitoring
        # (size_average=False is deprecated; reduction='sum' is the equivalent)
        loss_1 = nn.MSELoss(reduction='sum')
        loss_2 = nn.L1Loss(reduction='sum')
        loss_3 = nn.SmoothL1Loss(reduction='sum')
        loss_4 = L1_Charbonnier_loss_()  # user-defined, see below
        # `opt` is assumed to hold the parsed command-line options
        if opt.loss_function_order == 1:
            loss_function_1 = get_loss_function(opt.loss_function_1)
            loss = lambda x, y: 1 * loss_function_1(x, y)
        elif opt.loss_function_order == 2:
            loss_function_1 = get_loss_function(opt.loss_function_1)
            loss_function_2 = get_loss_function(opt.loss_function_2)
            weight_1 = opt.loss_function_1_weight
            weight_2 = opt.loss_function_2_weight
            loss = lambda x, y: (weight_1 * loss_function_1(x, y)
                                 + weight_2 * loss_function_2(x, y))
        elif opt.loss_function_order == 3:
            loss_function_1 = get_loss_function(opt.loss_function_1)
            loss_function_2 = get_loss_function(opt.loss_function_2)
            loss_function_3 = get_loss_function(opt.loss_function_3)
            weight_1 = opt.loss_function_1_weight
            weight_2 = opt.loss_function_2_weight
            weight_3 = opt.loss_function_3_weight
            loss = lambda x, y: (weight_1 * loss_function_1(x, y)
                                 + weight_2 * loss_function_2(x, y)
                                 + weight_3 * loss_function_3(x, y))
        elif opt.loss_function_order == 4:
            loss_function_1 = get_loss_function(opt.loss_function_1)
            loss_function_2 = get_loss_function(opt.loss_function_2)
            loss_function_3 = get_loss_function(opt.loss_function_3)
            loss_function_4 = get_loss_function(opt.loss_function_4)
            weight_1 = opt.loss_function_1_weight
            weight_2 = opt.loss_function_2_weight
            weight_3 = opt.loss_function_3_weight
            weight_4 = opt.loss_function_4_weight
            loss = lambda x, y: (weight_1 * loss_function_1(x, y)
                                 + weight_2 * loss_function_2(x, y)
                                 + weight_3 * loss_function_3(x, y)
                                 + weight_4 * loss_function_4(x, y))
        else:
            raise Exception("_criterion : unable to interpret loss_function_order")
        # argument order now matches forward(pred, ref); it was swapped before
        return loss(pred, ref), loss_1(pred, ref), loss_2(pred, ref), loss_3(pred, ref), loss_4(pred, ref)
def get_loss_function(loss):
    if loss == "MSE":
        criterion = nn.MSELoss(reduction='sum')
    elif loss == "MAE":
        criterion = nn.L1Loss(reduction='sum')
    elif loss == "Smooth-L1":
        criterion = nn.SmoothL1Loss(reduction='sum')
    elif loss == "Charbonnier":
        criterion = L1_Charbonnier_loss_()
    else:
        raise Exception("not implemented")
    return criterion
class L1_Charbonnier_loss_(nn.Module):
    """Charbonnier-style penalty: ~diff**2/2 near zero, ~sqrt(eps)*|diff| for large errors."""

    def __init__(self):
        super().__init__()
        self.eps = 1e-6

    def forward(self, X, Y):
        diff = X - Y
        error = self.eps * (torch.sqrt(1 + (diff * diff) / self.eps) - 1)
        return torch.sum(error)
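As a quick sanity check that this penalty behaves the way I expect (quadratic for small differences, roughly linear for large ones), I evaluated it at two points; the approximate values follow from the formula above:

```python
import torch

charb = L1_Charbonnier_loss_()
zero = torch.zeros(1)
for d in (1e-4, 1.0):
    print(d, charb(torch.full((1,), d), zero).item())
# d = 1e-4 -> ~5e-9   (approx. d**2 / 2, the quadratic regime)
# d = 1.0  -> ~1e-3   (approx. sqrt(eps) * |d|, the linear regime)
```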