Hi, I am trying to implementing a customized loss function as follows, can it work without a customized backward routine ? Actually, how do I know if the operation can be supported by autograd automatically so that I don’t need to specify backward pass?
class Loss(torch.autograd.Function):
'''
Implement the loss function from output from RNN.
Ref paper: https://arxiv.org/abs/1308.0850
'''
def __init__(self):
'''
x is sequence of coordinates with dim (batch, seq_length, 3).
Parameters are sequence of output from rnn with dim (batch, seq_length, 128).
'''
self.e = [] # predicted end of stroke probability scalar
self.m1 = [] # vector of means for x1 with len 20
self.m2 = [] # vector of means for x2 with len 20
self.pi = [] # vector of mixture density network coefficients with len 20
self.rho = [] # vector of correlation with len 20
self.s1 = [] # vector of standard deviation for x1 with len 20
self.s2 = [] # vector of standard deviation for x2 with len 20
self.x1 = [] # x1 coordinate at t+1
self.x2 = [] # x2 coordinates at t + 1
self.et = [] # end of probability indicator from ground truth
self.parameters = []
self.batch = 0 #batch size
self.seq_length = 0 # reduce by 1 because loss is caculated at t+1 timestamp
def forward(self, x, para):
'''
Implement eq 26 of ref paper for single time stamp.
'''
self.save_for_backward(para)
total_loss = 0
for i in range(self.seq_length):
# prepare parameters
self.__get_para(i, x, para)
normalpdf = self.__para2normal(self.x1, self.x2, self.m1, self.m2, self.s1, self.s2, self.rho) #dim (n_batch, 20)
single_loss = self.__singleLoss(normalpdf)
total_loss += single_loss
return total_loss
def __get_para(self, i, x, para):
'''
Slice and process parameters to the right form.
Implementing eq 18-23 of ref paper.
'''
self.batch = torch.size(x)[0]
self.e = torch.sigmoid(-para[:,i,0]) # eq 18
self.parameters = para
self.seq_length = torch.size(x)[1] -1 # reduce by 1 because loss is caculated at t+1 timestamp
# slice remaining parameters and training inputs
self.pi, self.m1, self.m2, self.s1, self.s2, self.rho = torch.split(self.parameters[:,i,1:], 6, dim = 1)
self.x1 = x[:,i+1,0]
self.x2 = x[:,i+1,1]
self.et = x[:,i+1,2]
## process parameters
# pi
max_pi = torch.max(self.pi, dim = 1)[0]
max_pi = max_pi.expand_as(self.pi)
self.pi.sub_(max_pi)
red_sum = torch.sum(self.pi, dim = 1).expand_as(self.pi)
self.pi.div_(red_sum)
# sd
self.s1.exp_()
self.s2.exp_()
# rho
self.rho.tanh_()
# reshape ground truth x1, x2 to match m1, m2 because broadcasting is currently not supported by pytorch
self.x1.expand_as(self.m1)
self.x2.expand_as(self.m2)
def __para2normal(self, x1, x2, m1, m2, s1, s2, rho):
'''
Implement eq 24, 25 of ref paper.
'''
norm1 = x1.sub(m1)
norm2 = x2.sub(m2)
s1s2 = torch.mul(s1, s2)
z = torch.pow(torch.div(norm1, s1), 2) + torch.pow(torch.div(norm2, s2), 2) - \
2*torch.div(torch.mul(pho, torch.mul(norm1, norm2)), s1s2)
negRho = 1 - torch.pow(rho, 2)
expPart = torch.exp(torch.div(-z, torch.mul(negRho, 2)))
coef = 2*np.pi*torch.mul(s1s2, torch.sqrt(negRho))
result = torch.div(expPart, coef)
return result
def __singleLoss(self, normalpdf):
'''
Calculate loss for single time stamp.
Input: normalpdf (n_batch, 20).
'''
epsilon = 1e-20 # floor of loss from mixture density component since initial loss could be zero
mix_den_loss = torch.mul(self.pi, normalpdf)
red_sum_loss = torch.sum(mix_den_loss) # sum for all batch
end_loss = torch.mul(self.e, self.et) + torch.mul(1-self.e, 1 - self.et)
total_loss = -torch.log(tf.max(red_sum_loss, epsilon)) - torch.log(end_loss)
return total_loss/self.batch