Broadcasting issue in Custom DoubleBackward

Hi All,

I’m having a weird broadcasting issue when computing gradients with a custom autograd Function. I’ve written a custom version of torch.slogdet with both a custom backward and a custom double backward extension. However, when I actually call backward I hit this error:

RuntimeError: The size of tensor a (4) must match the size of tensor b (6) at non-singleton dimension 3

I can see where it comes from, but I’m not sure how to resolve it. Here is the code for the custom torch.slogdet:

import torch
from torch.autograd import Function

class CustomSLogDeterminant(Function):
  
  @staticmethod
  def forward(ctx, A):
    ctx.save_for_backward(A)
    return torch.slogdet(A)    
    
  @staticmethod
  def backward(ctx, sgnbar, logabsdetbar):
    A, = ctx.saved_tensors
    return CustomSLogDeterminantBackward.apply(A, sgnbar, logabsdetbar) 
    
class CustomSLogDeterminantBackward(Function):

  @staticmethod
  def forward(ctx, A, sgnbar, logabsdetbar):
    #A shape [4096, 4, 6, 6]
    #sgnbar Shape [4096, 4]
    #logabsdetbar Shape [4096, 4]
    ctx.save_for_backward(A, sgnbar, logabsdetbar) 
    # This is the line the first RuntimeError points at: logabsdetbar is
    # [4096, 4] but the transposed inverse is [4096, 4, 6, 6]
    return logabsdetbar * torch.transpose(torch.linalg.inv(A), -2, -1)
    
  @staticmethod
  def backward(ctx, InverseTransposeAbar):
    A, sgnbar, logabsdetbar = ctx.saved_tensors
    
    InverseAbar = torch.transpose(InverseTransposeAbar, -2, -1)
    TransposeInverseA = torch.transpose(torch.linalg.inv(A), -2, -1)
        
    second_derivative = -(TransposeInverseA @ InverseAbar @ TransposeInverseA)
    return (logabsdetbar * second_derivative), None, InverseAbar @ TransposeInverseA
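
For reference, this is roughly how I’m exercising it (a minimal sketch with the same shapes as above, not my actual training code):

A = torch.randn(4096, 4, 6, 6, dtype=torch.double, requires_grad=True)
sgn, logabsdet = CustomSLogDeterminant.apply(A)
# First backward: runs CustomSLogDeterminantBackward.forward, which is where
# the size (4) vs (6) RuntimeError is raised
grad_A, = torch.autograd.grad(logabsdet.sum(), A, create_graph=True)
# Second backward: runs CustomSLogDeterminantBackward.backward
grad_A.sum().backward()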

Now the error above clearly comes from return logabsdetbar * torch.transpose(torch.linalg.inv(A), -2, -1), since logabsdetbar has shape [4096, 4] and the other term has shape [4096, 4, 6, 6]. The tensor of size [4096, 4, 6, 6] represents a batch of 4096 inputs of 4 matrices, each of size [6, 6], so I need logabsdetbar broadcast across dims 2 and 3 of that tensor. I tried modifying the line to unsqueeze dims 2 and 3, although that just pushed the error further along. It was modified as follows:

    return logabsdetbar.unsqueeze(2).unsqueeze(3) * torch.transpose(torch.linalg.inv(A), -2, -1)
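
With that, the broadcasting itself succeeds, since the shapes now line up:

    logabsdetbar.unsqueeze(2).unsqueeze(3).shape  # torch.Size([4096, 4, 1, 1]), broadcasts against [4096, 4, 6, 6]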

But then that just produces a different error:

RuntimeError: Function CustomSLogDeterminantBackwardBackward returned an invalid gradient at index 2 - got [4096, 4, 6, 6] but expected shape compatible with [4096, 4]
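
Reading that error, autograd seems to expect the gradient returned for logabsdetbar (input index 2) to match logabsdetbar’s own shape [4096, 4], rather than the [4096, 4, 6, 6] matrix product I’m currently returning. My best guess is that I need to contract over the two matrix dimensions, something like the sketch below (just my attempt, I haven’t verified the math is right):

  @staticmethod
  def backward(ctx, InverseTransposeAbar):
    A, sgnbar, logabsdetbar = ctx.saved_tensors

    InverseAbar = torch.transpose(InverseTransposeAbar, -2, -1)
    TransposeInverseA = torch.transpose(torch.linalg.inv(A), -2, -1)

    second_derivative = -(TransposeInverseA @ InverseAbar @ TransposeInverseA)
    # Broadcast logabsdetbar over the matrix dims, as in the forward
    grad_A = logabsdetbar.unsqueeze(2).unsqueeze(3) * second_derivative
    # Contract over the matrix dims so this gradient matches logabsdetbar's shape [4096, 4]
    grad_logabsdetbar = (InverseTransposeAbar * TransposeInverseA).sum(dim=(-2, -1))
    return grad_A, None, grad_logabsdetbar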

Is that the right approach, or am I missing something? How exactly can I solve this?

Thank you in advance!