I have a tensor, X:
X.shape = torch.Size([B,N,C,C])
I am trying to calculate the determinant of the elements along B and N, i.e. for each CxC matrix:
X_det = X.view(-1, C, C)  # flatten the batch dims: (B*N, C, C)
X_det = torch.det(X_det)  # one determinant per matrix: (B*N,)
X_det = X_det.view(B, N).unsqueeze(-1).repeat(1, 1, C)  # (B, N, C)
This code is exceptionally slow on the GPU, which is not the case when running on the CPU.
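To make the comparison reproducible, here is a minimal, self-contained sketch of the computation with a simple timing harness. The sizes `B, N, C = 8, 64, 3` and the helper names `batched_det` and `time_on` are placeholders chosen for illustration, not values from my actual model. The GPU timing calls `torch.cuda.synchronize()` because CUDA kernels launch asynchronously, and the GPU branch is skipped when no CUDA device is available.

```python
import time
import torch


def batched_det(X):
    """Determinant of every C x C matrix in a (B, N, C, C) tensor.

    Returns a (B, N, C) tensor with each determinant repeated C times.
    """
    B, N, C, _ = X.shape
    X_det = torch.det(X.view(-1, C, C))  # (B*N,)
    return X_det.view(B, N).unsqueeze(-1).repeat(1, 1, C)  # (B, N, C)


def time_on(device, X, repeats=10):
    """Average wall-clock seconds per call of batched_det on `device`."""
    X = X.to(device)
    batched_det(X)  # warm-up call, excluded from the measurement
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for pending kernels before timing
    start = time.perf_counter()
    for _ in range(repeats):
        batched_det(X)
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for the timed kernels to finish
    return (time.perf_counter() - start) / repeats


# Hypothetical sizes, chosen only for illustration.
B, N, C = 8, 64, 3
X = torch.randn(B, N, C, C)

cpu_time = time_on(torch.device("cpu"), X)
print(f"CPU: {cpu_time * 1e3:.3f} ms per call")

if torch.cuda.is_available():
    gpu_time = time_on(torch.device("cuda"), X)
    print(f"GPU: {gpu_time * 1e3:.3f} ms per call")
```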
I am using PyTorch 1.6.0.
What can I do to implement this function quickly on the GPU?