Hi All,

I’m having a weird broadcasting issue when calculating my gradients with a custom autograd `Function`. I’ve written a custom version of `torch.slogdet` with both a custom backward and a custom double-backward extension. However, I’m having an issue with actually using the backward due to this error:

```
RuntimeError: The size of tensor a (4) must match the size of tensor b (6) at non-singleton dimension 3
```
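
For context, this is roughly how I trigger it (a minimal sketch with random data at my real shapes; `CustomSLogDeterminant` is defined below):

```
import torch

# Random batch matching my real shapes: 4096 x 4 matrices of size 6x6.
A = torch.randn(4096, 4, 6, 6, dtype=torch.float64, requires_grad=True)

sign, logabsdet = CustomSLogDeterminant.apply(A)

# First backward, with create_graph=True so it can be differentiated again.
(grad_A,) = torch.autograd.grad(logabsdet.sum(), A, create_graph=True)

# Second backward, which goes through the custom double-backward.
grad_A.sum().backward()
```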

I can see where it is, but I’m not sure how to resolve it. The code of the custom `torch.slogdet` is here:

```
import torch
from torch.autograd import Function


class CustomSLogDeterminant(Function):
    @staticmethod
    def forward(ctx, A):
        ctx.save_for_backward(A)
        return torch.slogdet(A)

    @staticmethod
    def backward(ctx, sgnbar, logabsdetbar):
        A, = ctx.saved_tensors
        return CustomSLogDeterminantBackward.apply(A, sgnbar, logabsdetbar)


class CustomSLogDeterminantBackward(Function):
    @staticmethod
    def forward(ctx, A, sgnbar, logabsdetbar):
        # A shape [4096, 4, 6, 6]
        # sgnbar shape [4096, 4]
        # logabsdetbar shape [4096, 4]
        ctx.save_for_backward(A, sgnbar, logabsdetbar)
        return logabsdetbar * torch.transpose(torch.linalg.inv(A), -2, -1)

    @staticmethod
    def backward(ctx, InverseTransposeAbar):
        A, sgnbar, logabsdetbar = ctx.saved_tensors
        InverseAbar = torch.transpose(InverseTransposeAbar, -2, -1)
        TransposeInverseA = torch.transpose(torch.linalg.inv(A), -2, -1)
        second_derivative = -(TransposeInverseA @ InverseAbar @ TransposeInverseA)
        return (logabsdetbar * second_derivative), None, InverseAbar @ TransposeInverseA
```
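
For reference, the identity I’m implementing is the standard one, d log|det A| / dA = A^{-T} (scaled by the incoming gradient). The built-in `torch.slogdet` differentiates fine, so I believe the expected behaviour can be checked against it, e.g. (small double-precision sketch, since `gradcheck` at my full batch size would be slow):

```
import torch

# Small double-precision batch for gradcheck; my real shapes are too big.
A = torch.randn(8, 2, 6, 6, dtype=torch.float64, requires_grad=True)

# First- and second-order checks against the built-in slogdet.
torch.autograd.gradcheck(lambda x: torch.slogdet(x).logabsdet, (A,))
torch.autograd.gradgradcheck(lambda x: torch.slogdet(x).logabsdet, (A,))
```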

Now, the error above clearly comes from `return logabsdetbar * torch.transpose(torch.linalg.inv(A), -2, -1)`, since `logabsdetbar` has shape `[4096, 4]` while the other term has shape `[4096, 4, 6, 6]`. I need to broadcast `logabsdetbar` across dims 2 and 3 of the second tensor: the `[4096, 4, 6, 6]` tensor represents a batch of `4096` inputs of `4` matrices, each of size `[6, 6]`. I tried modifying the line to `unsqueeze` dims 2 and 3, although that just pushed the error further along. The modified line is:

```
return logabsdetbar.unsqueeze(2).unsqueeze(3) * torch.transpose(torch.linalg.inv(A), -2, -1)
```
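
In isolation that multiply does now broadcast the way I intend (toy stand-in tensors, not my real data):

```
import torch

logabsdetbar = torch.randn(4096, 4)
inv_t = torch.randn(4096, 4, 6, 6)

# Broadcasting right-aligns shapes, so plain [4096, 4] lines up against
# the trailing [6, 6] and fails; [4096, 4, 1, 1] spreads over each 6x6.
out = logabsdetbar.unsqueeze(2).unsqueeze(3) * inv_t
print(out.shape)  # torch.Size([4096, 4, 6, 6])
```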

But then that just returns another error:

```
RuntimeError: Function CustomSLogDeterminantBackwardBackward returned an invalid gradient at index 2 - got [4096, 4, 6, 6] but expected shape compatible with [4096, 4]
```

As I understand it, autograd expects each gradient returned from `backward` to match the shape of the corresponding input to `forward`, so the gradient at index 2 (for `logabsdetbar`) needs to end up with shape `[4096, 4]`. How exactly can I solve this?

Thank you in advance!