Manual backprop of gradient in last layer to loss


I have a model called ‘Net’, which is a stack of convolutional layers.
As my loss function I am using soft-DTW, which is a workaround that makes DTW differentiable. The gradient of soft-DTW is a matrix with the same size as my input data (x).

x: input data
Net: CNN
loss: sdtw (soft-DTW)
derivative of loss w.r.t. y = G (matrix)
y = Net(x) --> output data

Normally, when we compute a loss and call loss.backward(), the gradient of the loss w.r.t. the model parameters is computed by back-propagating from the last layer through the model. Here, however, my loss function is not one of PyTorch’s built-in losses; it is a self-defined function (soft-DTW). To calculate the gradient of the loss w.r.t. Net.parameters(), I use the chain rule:

(derivative of Loss w.r.t. y) * (derivative of y w.r.t. the model’s parameters)
If I define a Loss class that inherits from autograd and call loss.backward(), is there a way to pass G to loss.backward() so that it uses G as the gradient at the last layer and multiplies it by the gradient of y w.r.t. the model params?

y = Net(x)
loss = SDTWLoss(x,y)
# gradient of loss w.r.t. y is Loss_grad = G, a matrix where G.shape == x.shape

# right now in my loss function I have implemented the code as the following:
class SDTWLoss():
    def __init__(self, y_pred, y):
        self.y_pred = y_pred
        self.y = y
    def forward(self):
        _dtw = SoftDTW(self.y_pred, self.y ,device)
        dtw_value, dtw_grad, dtw_E_mat= _dtw.compute()
        self.dtw_grad = dtw_grad
        dtw_loss = torch.mean(dtw_value)
        return dtw_loss

    def backward(self):
        batch_size, _, n = self.y_pred.shape
        G = jacobian_product_with_expected_sdtw_grad(self.y_pred, self.y, self.dtw_grad,device)
        # one zero-initialised accumulator per parameter
        param_grad = [torch.zeros_like(param) for param in Net.parameters()]
        for k in range(batch_size):
            for i in range(n):
                for j, param in enumerate(Net.parameters()):
                    param_grad[j] = param_grad[j] + (G[k, 0, i] * param.grad)

        for j,param in enumerate(Net.parameters()):
            param.grad = param_grad[j]

This code is extremely slow!
I think the fastest way would be to have the option of passing G as an input to loss.backward:
loss.backward(G)
Any suggestions are appreciated!
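
On the idea of passing G directly: Tensor.backward does accept a gradient argument, so the product of G with the derivative of y w.r.t. the parameters can be requested on y itself rather than on the loss. The following is only a minimal sketch under stated assumptions: a one-layer Conv1d stands in for Net, and the gradient of a mean-squared error stands in for the soft-DTW matrix G. Neither placeholder comes from the code in this thread.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for Net: a single 1-D convolution.
net = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, padding=1)

x = torch.randn(4, 1, 16)       # (batch, channels, length)
target = torch.randn(4, 1, 16)  # hypothetical target, same shape as y

y = net(x)

# Stand-in for the manually computed dLoss/dy.
# Here Loss = mean((y - target)^2), so dLoss/dy = 2 * (y - target) / N.
G = (2.0 * (y - target) / y.numel()).detach()

# Vector-Jacobian product: accumulates sum(G * dy/dparam) into param.grad.
y.backward(gradient=G)
manual_grads = [p.grad.clone() for p in net.parameters()]

# Reference: let autograd differentiate the same loss end to end.
net.zero_grad()
loss = ((net(x) - target) ** 2).mean()
loss.backward()
auto_grads = [p.grad.clone() for p in net.parameters()]

for m, a in zip(manual_grads, auto_grads):
    print(torch.allclose(m, a, atol=1e-5))
```

With this pattern the triple loop over batch, time step and parameter is not needed, because autograd performs the whole product in a single backward pass.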

Have you tried this?

My understanding is that this example implements the ReLU activation function. I need to implement a loss function for the specific case where the last layer’s derivative is calculated manually and is back-propagated through the network using the chain rule.

Yeah. This is a ReLU example, but you can extend it to your own problem. The values returned by backward are the derivatives with respect to the inputs of forward; in your case that input is the output of the second-to-last layer. Your loss is implemented in forward, or constructed from the output of forward. You just need to write your last layer’s derivative manually in backward.
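
To make the mechanics concrete, here is a small sketch of a torch.autograd.Function whose backward returns a hand-written derivative for the single input of forward. The class name ManualSquare and the sum-of-squares formula are illustrative assumptions, chosen only because the derivative is easy to verify; they are not soft-DTW.

```python
import torch


class ManualSquare(torch.autograd.Function):
    """Computes sum(x**2) with a hand-written derivative (illustrative only)."""

    @staticmethod
    def forward(ctx, x):
        # Keep the input around for the backward pass.
        ctx.save_for_backward(x)
        return (x ** 2).sum()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d(sum(x^2))/dx = 2x, chained with the incoming gradient.
        return grad_output * 2.0 * x


# Double precision so the numerical check below is meaningful.
x = torch.randn(5, dtype=torch.double, requires_grad=True)

out = ManualSquare.apply(x)
out.backward()
print(torch.allclose(x.grad, 2.0 * x.detach()))

# gradcheck compares the manual backward against finite differences.
print(torch.autograd.gradcheck(ManualSquare.apply, (x,)))
```

Once backward returns the gradient for the input of forward, autograd continues the chain rule through whatever produced that input, so nothing in the function needs to touch the model parameters.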

Thanks for clarifying. I am still not completely there.
So let’s say I have the following:

class SDTWLoss(torch.autograd.Function):

    def forward(ctx, y_pred, y):
        _dtw = SoftDTW(y_pred, y, device)
        # calculate the loss value for y_pred and the actual y
        dtw_value, dtw_grad, dtw_E_mat = _dtw.compute()
        # calculate the derivative of the loss w.r.t. y_pred (last layer of my graph)
        G = jacobian_product_with_expected_sdtw_grad(y_pred, y, dtw_grad, device)
        ctx.save_for_backward(G)  # save derivative of loss w.r.t. y for backward
        dtw_loss = torch.mean(dtw_value)

        return dtw_loss

    def backward(ctx, grad_output):

        G = ctx.saved_tensors
        grad_input = grad_output.clone()  # is this the derivative of the layer before the last layer w.r.t. the model's parameters???
        # if so, I should multiply G by grad_input here, right?
        grad_input = G * grad_input

        return grad_input

Does this code make sense?

The logic is correct here.

Unfortunately it’s not working yet!
If grad_output is the gradient w.r.t. y_pred, I am already calculating that manually (the matrix G). What should I multiply it by to get the gradient w.r.t. the model’s params?

You need to return grad_input, None from your backward, since you do not need a gradient for the true y.
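
A sketch of what that looks like end to end, under assumptions: the class ManualLoss, the Conv1d stand-in for the network, and the mean-squared-error gradient used in place of the soft-DTW matrix G are all hypothetical. Two details matter here besides the extra None: ctx.saved_tensors is a tuple and has to be unpacked, and the manual gradient is only multiplied by grad_output. Autograd then carries it on to the model parameters by itself.

```python
import torch
import torch.nn as nn


class ManualLoss(torch.autograd.Function):
    """Loss with a manually supplied gradient w.r.t. y_pred (illustrative)."""

    @staticmethod
    def forward(ctx, y_pred, y):
        diff = y_pred - y
        # Stand-in for G: derivative of mean(diff^2) w.r.t. y_pred.
        G = 2.0 * diff / diff.numel()
        ctx.save_for_backward(G)
        return (diff ** 2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        (G,) = ctx.saved_tensors  # saved_tensors is a tuple
        # One return value per forward input: a gradient for y_pred, None for y.
        return grad_output * G, None


torch.manual_seed(0)
net = nn.Conv1d(1, 1, kernel_size=3, padding=1)  # hypothetical stand-in model
x = torch.randn(4, 1, 16)
y = torch.randn(4, 1, 16)

loss = ManualLoss.apply(net(x), y)
loss.backward()  # fills param.grad for every parameter of net
manual = [p.grad.clone() for p in net.parameters()]

# Reference gradients from the built-in loss with the same formula.
net.zero_grad()
nn.functional.mse_loss(net(x), y).backward()
reference = [p.grad.clone() for p in net.parameters()]

print(all(torch.allclose(m, r, atol=1e-5) for m, r in zip(manual, reference)))
```

Note that the custom function is invoked through ManualLoss.apply(...) rather than by instantiating the class.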