I have a model called `Net`, which is a bunch of convolutional layers.
As my loss function I am using soft-DTW, which has a workaround to make DTW differentiable. The gradient of soft-DTW is a matrix with the same size as my input data (x).
x: input data
y = Net(x) --> output data
derivative of the loss w.r.t. y = G (a matrix)
Normally, when we calculate a loss and call loss.backward(), the gradient of the loss w.r.t. the model parameters is calculated and back-propagated through the model. Here, however, my loss is not one of the PyTorch loss functions; it is a self-defined function (soft-DTW). To calculate the gradient of the loss w.r.t. Net.parameters(), I use the chain rule:

(derivative of the loss w.r.t. y) * (derivative of y w.r.t. the model's parameters)

So, if I define a Loss class that inherits from torch.autograd.Function and call loss.backward(), is there a way to pass G to loss.backward() so that it uses G as the gradient at the last layer and multiplies it by the gradient of y w.r.t. the model parameters?
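To make the question concrete, what I am imagining is roughly the following sketch of such a class (only a sketch: `SoftDTW`, its `compute()` method, and `device` come from my own soft-DTW code, the class name `SDTWLossFn` is just a placeholder, and I am ignoring how G should be scaled for the mean over the batch):

```python
import torch

class SDTWLossFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, y_pred, y):
        # my own soft-DTW: returns the loss values and their gradient w.r.t. y_pred
        _dtw = SoftDTW(y_pred, y, device)
        dtw_value, dtw_grad, _ = _dtw.compute()
        ctx.save_for_backward(dtw_grad)      # stash G for the backward pass
        return torch.mean(dtw_value)

    @staticmethod
    def backward(ctx, grad_output):
        (G,) = ctx.saved_tensors
        # grad_output is the upstream gradient (1.0 when loss.backward() is called);
        # returning grad_output * G lets autograd multiply it by dy/dparams for me
        return grad_output * G, None
```

which I would then use as `loss = SDTWLossFn.apply(y_pred, y)` followed by the usual `loss.backward()`.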
The training loop I am aiming for is:

```python
y = Net(x)
loss = SDTWLoss(x, y)   # the gradient of the loss w.r.t. y is Loss_grad = G, a matrix with G.shape == x.shape
loss.backward()
optimizer.step()
```

Right now, I have implemented the loss function as follows:

```python
class SDTWLoss():
    def __init__(self, y_pred, y):
        self.y_pred = y_pred
        self.y = y

    def forward(self):
        _dtw = SoftDTW(self.y_pred, self.y, device)
        dtw_value, dtw_grad, dtw_E_mat = _dtw.compute()
        self.dtw_grad = dtw_grad
        dtw_loss = torch.mean(dtw_value)
        return dtw_loss

    def backward(self):
        batch_size, _, n = self.y_pred.shape
        G = jacobian_product_with_expected_sdtw_grad(self.y_pred, self.y, self.dtw_grad, device)

        # accumulators for the per-parameter gradients, one per model parameter
        param_grad = [torch.zeros_like(param) for param in Net.parameters()]

        # back-propagate every output element on its own and weight
        # the resulting parameter gradients by the corresponding entry of G
        for k in range(batch_size):
            for i in range(n):
                Net.zero_grad()
                self.y_pred[k, 0, i].backward(retain_graph=True)
                for j, param in enumerate(Net.parameters()):
                    param_grad[j] = param_grad[j] + G[k, 0, i] * param.grad

        # write the accumulated gradients back onto the model parameters
        for j, param in enumerate(Net.parameters()):
            param.grad = param_grad[j]
```
This code is extremely slow!

I think the fastest way would be to have the option of passing G directly to loss.backward():

```python
loss.backward(G)
```
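Or, if I understand the docs correctly, PyTorch already allows passing a gradient tensor to .backward() on a non-scalar tensor, so maybe I could skip the scalar loss altogether and back-propagate G through the network output directly (rough sketch, with G computed beforehand by my own soft-DTW code):

```python
y = Net(x)        # network output, same shape as x
# G: soft-DTW gradient of the loss w.r.t. y, computed by my own code, G.shape == y.shape
Net.zero_grad()
y.backward(G)     # vector-Jacobian product: accumulates G * dy/dparams into param.grad
optimizer.step()
```

As far as I can tell, this would do in a single pass what the nested loops above do element by element.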
Any suggestion is appreciated!