Hi,
I have a model called `Net`, which is a bunch of convolutional layers.
As my loss function I am using soft-DTW, which has a work-around that makes DTW differentiable. The gradient of soft-DTW is a matrix with the same shape as my input data (x).
x: input data
Net: CNN
loss: soft-DTW
derivative of loss w.r.t. y = G (a matrix)
y = Net(x) --> output data
Normally, when we compute the loss and call loss.backward(), the gradient of the loss w.r.t. the model parameters is calculated and back-propagated through the model. Here, however, my loss function is not from the PyTorch loss library; it is a self-defined function (soft-DTW). To calculate the gradient of the loss w.r.t. Net.parameters(), I use the chain rule:
derivative of loss w.r.t. y * derivative of y w.r.t. the model's parameters
If I define a loss class that inherits from autograd and call loss.backward(), is there a way to pass G to loss.backward() so that it uses G as the gradient at the last layer and multiplies it by the gradient of y w.r.t. the model parameters?
```python
y = Net(x)
loss = SDTWLoss(x, y)
# gradient of loss w.r.t. y is G, a matrix where G.shape == x.shape
loss.backward()
optimizer.step()
```
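For background on why this should be possible: `Tensor.backward` accepts a `gradient` argument for non-scalar tensors, and `y.backward(G)` computes exactly the vector-Jacobian product G · dy/dθ. A minimal sketch with a toy linear model (the model and shapes below are placeholders, not my actual `Net`):

```python
import torch

# Toy stand-in for Net: a single linear layer
net = torch.nn.Linear(4, 3, bias=False)
x = torch.randn(2, 4)      # batch of inputs
y = net(x)                 # y.shape == (2, 3)

# Pretend G is the externally computed dLoss/dy
G = torch.randn_like(y)

# Vector-Jacobian product: accumulates sum(G * dy/dW) into net.weight.grad
y.backward(G)

# Sanity check against the equivalent scalar formulation (G * y).sum()
manual = torch.autograd.grad((G * net(x)).sum(), [net.weight])[0]
assert torch.allclose(net.weight.grad, manual)
```

So if G is already available, calling `y.backward(G)` on the network output performs the whole chain-rule contraction in one pass.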
Right now I have implemented the loss as follows:

```python
class SDTWLoss():
    def __init__(self, y_pred, y):
        self.y_pred = y_pred
        self.y = y

    def forward(self):
        _dtw = SoftDTW(self.y_pred, self.y, device)
        dtw_value, dtw_grad, dtw_E_mat = _dtw.compute()
        self.dtw_grad = dtw_grad
        dtw_loss = torch.mean(dtw_value)
        return dtw_loss

    def backward(self):
        batch_size, _, n = self.y_pred.shape
        G = jacobian_product_with_expected_sdtw_grad(self.y_pred, self.y, self.dtw_grad, device)
        # Accumulators for the parameter gradients (start from zero)
        param_grad = [torch.zeros_like(p) for p in Net.parameters()]
        # Back-propagate each output element separately and weight it by G
        for k in range(batch_size):
            for i in range(n):
                Net.zero_grad()
                self.y_pred[k, 0, i].backward(retain_graph=True)
                for j, param in enumerate(Net.parameters()):
                    param_grad[j] = param_grad[j] + G[k, 0, i] * param.grad
        for j, param in enumerate(Net.parameters()):
            param.grad = param_grad[j]
```
This code is extremely slow! I think the fastest way would be to have the option of passing G directly to backward:

```python
loss.backward(G)
```
Any suggestion is appreciated!