You could use this as a reference.
The demo in your link is quite beginner-friendly. I tried to adapt it to meet your requirement, although I am not sure whether it works. I think using a custom function, as you mentioned, is the better way because it makes the internal mechanism clearer.
import torch.nn as nn
from torch.autograd import Function

class WeightedLossFunc(Function):
    # Elementwise-weighted MSE with a hand-written backward pass.
    @staticmethod
    def forward(ctx, input, target, weight):
        ctx.save_for_backward(input, target, weight)
        # Weighted squared error, summed per sample, then averaged over the batch.
        return ((target - input) ** 2 * weight).view(target.size(0), -1).sum(dim=1, keepdim=True).mean()

    @staticmethod
    def backward(ctx, grad_output):
        input, target, weight = ctx.saved_tensors
        batch_size = target.size(0)
        input_grad, target_grad, weight_grad = None, None, None
        differ = target - input
        # dL/d(input) = -2 * weight * (target - input) / batch_size,
        # scaled by grad_output so the Function chains correctly.
        input_grad = -2 * grad_output * differ * weight / batch_size
        # dL/d(weight) = (target - input) ** 2 / batch_size
        weight_grad = grad_output * differ ** 2 / batch_size
        # target gets no gradient; return one entry per forward() argument.
        return input_grad, target_grad, weight_grad

class WeightedLoss(nn.Module):
    def __init__(self, weight):
        super(WeightedLoss, self).__init__()
        self.weight = weight
        self.lossfunc = WeightedLossFunc.apply

    def forward(self, input, target):
        return self.lossfunc(input, target, self.weight)
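If you want to verify it, torch.autograd.gradcheck can compare the hand-written backward against numerical gradients. A quick sketch (the shapes here are made up just for illustration):

import torch
from torch.autograd import gradcheck

# Made-up shapes for illustration: a batch of 4 samples, 3 features each.
# gradcheck needs double precision to keep the numerical comparison stable.
input = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
target = torch.randn(4, 3, dtype=torch.double)
weight = torch.rand(4, 3, dtype=torch.double, requires_grad=True)

criterion = WeightedLoss(weight)
loss = criterion(input, target)
loss.backward()

# Should print True if the hand-written backward matches autograd's estimate.
print(gradcheck(WeightedLossFunc.apply, (input, target, weight)))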
Here is another thread about building our own custom loss functions; I have not had time to read it, but I hope it helps you.
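For what it's worth, if you only need the forward formula, you can also skip the custom Function entirely and let autograd derive the backward for you. A minimal sketch of the same loss (AutogradWeightedLoss is just a name I made up):

import torch.nn as nn

class AutogradWeightedLoss(nn.Module):
    # Same weighted MSE as above, but autograd derives the backward for us.
    def __init__(self, weight):
        super(AutogradWeightedLoss, self).__init__()
        self.weight = weight

    def forward(self, input, target):
        return ((target - input) ** 2 * self.weight).view(target.size(0), -1).sum(dim=1).mean()

The custom Function version is mainly useful when you want explicit control over the gradients.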
If you find a better solution, let me know.
Thank you!