Does a custom piecewise loss function need a backward computation?

I am defining a piecewise weight term for a distance loss. Although the function has discontinuities, a gradient is defined at every point by design. Given that, do I need to write a backward function for the following forward computation of my loss module?

def forward(self, x, y):

    # Squared distance per element
    d = torch.pow(x - y, 2)

    # Masks for the three pieces, computed once instead of re-indexing
    low = d < 50
    mid = (d >= 50) & (d <= 100)
    high = d > 100

    # First piece: scaled quadratic term, d * (d / 50)
    loss = torch.sum(torch.mul(d[low], d[low] / 50))

    # Second piece: cubic term, d * d^2
    loss += torch.sum(torch.mul(d[mid], torch.pow(d[mid], 2)))

    # Third piece: constant weight on the remaining elements
    loss += torch.sum(d[high] * 10000)

    return loss

(The numbers are currently arbitrary, but the idea is the same)
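
If it helps clarify the question, below is a standalone sketch of the check I have in mind. The piecewise_loss wrapper, input shapes, and scale factor are placeholders I made up for the test, not my real model; the loss body is the same as above.

import torch

def piecewise_loss(x, y):
    # Same forward pass as above, repeated so the snippet runs standalone
    d = torch.pow(x - y, 2)
    low = d < 50
    mid = (d >= 50) & (d <= 100)
    high = d > 100
    loss = torch.sum(torch.mul(d[low], d[low] / 50))
    loss += torch.sum(torch.mul(d[mid], torch.pow(d[mid], 2)))
    loss += torch.sum(d[high] * 10000)
    return loss

# Scale the random inputs so all three pieces of the loss get exercised
x = (10 * torch.randn(8, 4)).requires_grad_()
y = 10 * torch.randn(8, 4)

loss = piecewise_loss(x, y)
loss.backward()
print(x.grad)  # populated gradients would suggest no custom backward is needed

My assumption is that if x.grad comes back populated (rather than None), autograd has differentiated through the masked indexing on its own, but I would like to confirm that this reasoning is sound.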
