I am defining a piecewise weight term for a distance loss. Although the function has discontinuities, a gradient is defined at every point. Do I need to write a custom backward function for the following forward computation in my loss module?
```python
import torch

def forward(self, x, y):
    # Compute the squared distance
    d = torch.pow(x - y, 2)
    # Compute the first component of the loss
    loss = torch.sum(torch.mul(d[d < 50], d[d < 50] / 50))
    # Compute the middle component of the loss (d * d**2)
    loss += torch.sum(torch.mul(d[(d >= 50) & (d <= 100)], torch.pow(d[(d >= 50) & (d <= 100)], 2)))
    # Compute the remaining component (constant weight of 10000)
    loss += torch.sum(d[d > 100] * 10000)
    return loss
```
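Written out per element, with $d_i = (x_i - y_i)^2$, the forward pass computes the loss and gradient below. Note that the middle branch multiplies $d$ by $d^2$, so it is cubic in $d$.

$$
L = \sum_i \ell(d_i), \qquad
\ell(d) = \begin{cases} d^2/50 & d < 50 \\ d^3 & 50 \le d \le 100 \\ 10000\,d & d > 100 \end{cases}
$$

$$
\frac{\partial L}{\partial x_i} = \ell'(d_i)\cdot 2(x_i - y_i), \qquad
\ell'(d) = \begin{cases} 2d/50 & d < 50 \\ 3d^2 & 50 \le d \le 100 \\ 10000 & d > 100 \end{cases}
$$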
(The numbers are currently arbitrary, but the idea is the same.)
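One way to check whether autograd already handles this piecewise loss is to compare its gradient with a hand-derived one. This is a minimal sketch under stated assumptions: `piecewise_loss` and `manual_grad` are hypothetical helper names, the forward pass is copied from the question into a free function, and the inputs are random float64 tensors.

```python
import torch

def piecewise_loss(x, y):
    # Same forward computation as the question's loss, as a free function
    d = torch.pow(x - y, 2)
    low = d < 50
    mid = (d >= 50) & (d <= 100)
    high = d > 100
    loss = torch.sum(d[low] * d[low] / 50)
    loss = loss + torch.sum(d[mid] * torch.pow(d[mid], 2))
    loss = loss + torch.sum(d[high] * 10000)
    return loss

def manual_grad(x, y):
    # Hypothetical helper: hand-derived dL/dx = l'(d) * 2 * (x - y)
    diff = x - y
    d = diff ** 2
    dl_dd = torch.where(
        d < 50,
        2 * d / 50,                      # derivative of d^2 / 50
        torch.where(
            d <= 100,
            3 * d ** 2,                  # derivative of d^3
            torch.full_like(d, 10000.0), # derivative of 10000 * d
        ),
    )
    return dl_dd * 2 * diff

torch.manual_seed(0)
# Scale the inputs so that all three branches are populated
x = (torch.randn(1000, dtype=torch.float64) * 8).requires_grad_()
y = torch.zeros(1000, dtype=torch.float64)

piecewise_loss(x, y).backward()
print(torch.allclose(x.grad, manual_grad(x.detach(), y)))
```

If this prints `True`, autograd is differentiating each branch separately. The boolean indexing routes each element's gradient through whichever branch it falls in, and the jumps at the branch boundaries contribute nothing. This matches the "gradient defined at every point" behaviour, so no custom backward should be needed for this form.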