The following code fails when N < 32:

```python
import torch

N = 31  # any N < 32 triggers the error
a = torch.ones(N, N)
indices = [(i, i) for i in range(N)]  # list of (row, col) pairs
print(a[indices])
```
It raises:

```
IndexError: too many indices for tensor of dimension 2
```
With N >= 32, however, the same code runs without an error. Could this be a bug in PyTorch itself?
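A possible explanation, offered as an assumption rather than a confirmed diagnosis: PyTorch appears to mirror an old NumPy heuristic for non-tuple sequences, where a list shorter than 32 that contains sequences is unpacked as if it were a tuple of per-dimension indices. Under that reading, the N = 31 list means `a[(0, 0), (1, 1), ..., (30, 30)]`, i.e. 31 indices for a 2-D tensor, while the N = 32 list becomes a single (N, 2) integer index on dimension 0. It then "works", but returns whole rows with shape (N, 2, N) instead of the (i, i) elements. The sketch below checks both cases; `probe` is a hypothetical helper name.

```python
import torch


def probe(n):
    """Index an n x n tensor with a list of (i, i) pairs, as in the question.

    Returns the result shape, or the IndexError message if indexing fails.
    `probe` is a hypothetical helper used only for this check.
    """
    a = torch.ones(n, n)
    indices = [(i, i) for i in range(n)]
    try:
        return tuple(a[indices].shape)
    except IndexError as e:
        return f"IndexError: {e}"


for n in (31, 32):
    print(n, probe(n))

# If the short list is really unpacked like a tuple, an explicit tuple
# of the same pairs should fail in the same way for n = 31.
a31 = torch.ones(31, 31)
try:
    a31[tuple((i, i) for i in range(31))]
except IndexError as e:
    print("explicit tuple ->", e)
```

If the heuristic is what is happening, the n = 31 line prints the same IndexError as in the question, and the n = 32 line prints a 3-D shape rather than `(32,)`.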
My PyTorch version is 1.11.0.
I found a similar question on Stack Overflow: "Pytorch tensor indexing error for sizes M < 32?"
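If the goal is to gather the elements at positions (i, i), one way to avoid the length-dependent behavior is to transpose the list of pairs into one index sequence per dimension and pass them as a tuple, which means "dimension 0, dimension 1" for any N. This is a sketch assuming that gathering those elements is the intent; for this particular diagonal pattern, `torch.diagonal` returns the same values.

```python
import torch

N = 31
# Distinct values make it easy to see which elements were selected.
a = torch.arange(N * N, dtype=torch.float32).reshape(N, N)
indices = [(i, i) for i in range(N)]

# Split [(r0, c0), (r1, c1), ...] into (r0, r1, ...) and (c0, c1, ...).
rows, cols = zip(*indices)

# a[list, list] is a tuple of two index lists, so each pair
# (rows[k], cols[k]) selects exactly one element.
picked = a[list(rows), list(cols)]

print(picked.shape)                              # torch.Size([31])
print(torch.equal(picked, torch.diagonal(a)))    # True
```

Because the two index lists are always interpreted per dimension, the same code behaves identically for N >= 32.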