Custom DoubleBackward method returns invalid shape

Hi All,

I just have a general question about writing custom autograd Functions that have built-in DoubleBackward methods. I’ve been writing a custom determinant function and I’m getting a shape mismatch in my DoubleBackward call.

From my understanding, the function takes a Tensor A of shape [B, N, N] and outputs a scalar loss per batch element, so the output is a Tensor of shape [B,]. In backward I calculate a 1st-order gradient w.r.t. A, which has shape [B, N, N] (and which I do get correctly). In my DoubleBackward I then need to calculate a 2nd-order gradient w.r.t. A and a gradient w.r.t. the grad_output of the 1st backward. In short, it looks something like this:

import torch
from torch.autograd import Function

class Determinant(Function):
  
  @staticmethod
  def forward(ctx, A):
    ctx.save_for_backward(A)
    return torch.linalg.det(A) # shape [B,]
    
  @staticmethod
  def backward(ctx, grad_output1):
    A, = ctx.saved_tensors
    return DeterminantBackward.apply(A, grad_output1)
    
class DeterminantBackward(Function):

  @staticmethod
  def forward(ctx, A, grad_output1):
    ctx.save_for_backward(A, grad_output1)
    """
    insert maths here
    ...
    calculate 1st order grad of A (A_grad) #shape [B,N,N]
    #shape of grad_output1 [B,]
    
    """
    # broadcast the batch of scalars [B,] across the batch of matrices [B,N,N]
    return grad_output1[:, None, None] * A_grad

  @staticmethod
  def backward(ctx, grad_output2):
    A, grad_output1 = ctx.saved_tensors
    """
    insert maths here
    ...
    calculate 2nd order grad w.r.t A (A_grad_grad) #shape [B,N,N]
    ...
    calculate 2nd order grad w.r.t grad_output1 (A_grad_grad_output1) #shape [B,]
    #shape of grad_output2 [B,N,N]
    """
    return grad_output2 * A_grad_grad, grad_output2 * A_grad_grad_output1 # <- the second term is where the shapes clash
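(For reference, the elided first-order maths is essentially Jacobi’s formula, ∂det(A)/∂A = det(A)·A⁻ᵀ; the batched version in my actual code is equivalent to roughly this:)

# rough sketch of the elided first-order gradient (Jacobi's formula), batched over B
det_A = torch.linalg.det(A)                                             # shape [B,]
A_grad = det_A[:, None, None] * torch.linalg.inv(A).transpose(-2, -1)   # shape [B,N,N]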

But a problem arises here. The first-order gradient is the product of a [B,] tensor and a [B,N,N] tensor (i.e. a batch of scalars multiplied across a batch of matrices). When I differentiate this expression with respect to grad_output1, I analytically get a batch of matrices, which by definition has a different shape from the original batch of scalars. So I’m kinda at a loss (pun not intended) as to why PyTorch requires the gradient for grad_output1 to have a different shape than what I derive analytically!
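In case it helps, this is roughly how I trigger the error (a minimal sketch, the tensor sizes are just for illustration):

B, N = 4, 3
A = torch.randn(B, N, N, dtype=torch.double, requires_grad=True)
out = Determinant.apply(A)                                      # shape [B,]
grad1, = torch.autograd.grad(out.sum(), A, create_graph=True)   # shape [B,N,N], goes through Determinant.backward
grad2, = torch.autograd.grad(grad1.sum(), A)                    # goes through DeterminantBackward.backward,
                                                                # and PyTorch complains about the shape returned for grad_output1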

Any help would be greatly appreciated!

Thank you! :slight_smile: