Hi All,

I have a general question about writing custom autograd Functions with built-in `DoubleBackward` methods. I've been writing a custom determinant function and I'm getting a shape mismatch in my `DoubleBackward` call.

From my understanding, the function takes a Tensor `A` of shape `[B, N, N]` and outputs a batch of scalar losses, so the output is a Tensor of shape `[B,]`. In `backward` I then calculate a first-order gradient for `A`, which has shape `[B, N, N]` (this part works). Then in my double backward I need to calculate a second-order gradient with respect to `A` and with respect to the `grad_output` of the first backward. In short, it looks something like this:

```
import torch
from torch.autograd import Function


class Determinant(Function):
    @staticmethod
    def forward(ctx, A):
        ctx.save_for_backward(A)
        return torch.linalg.det(A)

    @staticmethod
    def backward(ctx, grad_output1):
        A, = ctx.saved_tensors
        # delegate to a second Function so the backward pass is itself
        # differentiable (this is what enables double backward)
        return DeterminantBackward.apply(A, grad_output1)


class DeterminantBackward(Function):
    @staticmethod
    def forward(ctx, A, grad_output1):
        ctx.save_for_backward(A, grad_output1)
        # insert maths here
        # ...
        # calculate 1st-order grad of A (A_grad), shape [B, N, N]
        # shape of grad_output1: [B,]
        # unsqueeze to [B, 1, 1] so the batch of scalars broadcasts
        # across the batch of matrices
        return grad_output1[:, None, None] * A_grad

    @staticmethod
    def backward(ctx, grad_output2):
        A, grad_output1 = ctx.saved_tensors
        # insert maths here
        # ...
        # calculate 2nd-order grad w.r.t. A (A_grad_grad), shape [B, N, N]
        # ...
        # calculate 2nd-order grad w.r.t. grad_output1 (A_grad_grad_output1), shape [B,]
        # shape of grad_output2: [B, N, N]
        return grad_output2 * A_grad_grad, grad_output2 * A_grad_grad_output1
```
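
For concreteness, here's a minimal runnable sketch of just the first-order part, with the maths filled in via the Jacobi identity d det(A)/dA = det(A) · A⁻ᵀ (this assumes `A` is invertible, and `det_first_order_grad` is only an illustrative stand-in for the body of `DeterminantBackward.forward`, not my actual maths):

```
import torch

def det_first_order_grad(A, grad_output1):
    # Jacobi identity: d det(A)/dA = det(A) * inv(A)^T, shape [B, N, N]
    # (assumes A is invertible)
    A_grad = torch.linalg.det(A)[:, None, None] * torch.linalg.inv(A).transpose(-2, -1)
    # grad_output1 has shape [B,]; unsqueeze to [B, 1, 1] so the batch of
    # scalars broadcasts across the batch of matrices
    return grad_output1[:, None, None] * A_grad

B, N = 4, 3
A = torch.randn(B, N, N, dtype=torch.float64, requires_grad=True)
grad_output1 = torch.randn(B, dtype=torch.float64)

# cross-check against autograd's own first-order gradient
(auto_grad,) = torch.autograd.grad(torch.linalg.det(A), A, grad_outputs=grad_output1)
print(torch.allclose(auto_grad, det_first_order_grad(A, grad_output1)))  # True
```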

But a problem arises here. The first-order gradient has shape `[B, N, N]`: it is `grad_output1` of shape `[B,]` multiplied element-wise by a tensor of shape `[B, N, N]` (i.e. a batch of scalars broadcast across a batch of matrices). However, differentiating this expression with respect to `grad_output1` yields a batch of matrices, which by definition has a different shape than the original batch of scalars. So I'm kinda at a loss (pun not intended) as to why PyTorch requires a differently shaped Tensor than the one I derive analytically!
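
To show the constraint concretely: autograd really does check that each gradient returned from `backward` matches the shape of the corresponding input. Here is a toy `Function` (purely illustrative, nothing to do with determinants) that returns a `[B, 3, 3]` gradient for a `[B,]`-shaped input, which autograd rejects:

```
import torch
from torch.autograd import Function

class Toy(Function):
    @staticmethod
    def forward(ctx, x):  # x: [B,]
        return 2.0 * x

    @staticmethod
    def backward(ctx, grad_out):
        # deliberately return a [B, 3, 3] gradient for a [B,]-shaped input
        return grad_out[:, None, None].expand(-1, 3, 3)

x = torch.randn(4, requires_grad=True)
# autograd raises a RuntimeError here: the gradient returned for x must be
# shape-compatible with x ([4,]), not [4, 3, 3]
Toy.apply(x).sum().backward()
```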

Any help would be greatly appreciated!

Thank you!