PyTorch newbie here. I have a tensor “AAA” with 5 dimensions, whose sizes are 1, 2, Z, Y, X.
To keep the question simple, say we split AAA into two tensors along the second dimension, so that the resulting tensors A and B each have shape [Z, Y, X].
For each element, I want to check whether B > A (elementwise) and whether B > a_threashold (a scalar value).
If both conditions are true, I want that value from B and its index. I guess it would probably be more efficient to do this directly on the original tensor AAA, without creating the two child tensors A and B.
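A minimal sketch of a vectorized version of this check, assuming AAA is a float tensor of shape (1, 2, Z, Y, X) and a_threashold is a Python float. In PyTorch, basic indexing such as AAA[0, 0] and AAA[0, 1] returns views that share AAA's memory, so forming A and B this way does not copy any data. The two comparisons build one boolean mask, and torch.nonzero turns it into (z, y, x) indices. The helper name select_b_above_a and the small example sizes are hypothetical, chosen only for illustration.

```python
import torch


def select_b_above_a(AAA: torch.Tensor, a_threashold: float):
    """Return values of B and their (z, y, x) indices where B > A and B > a_threashold.

    Assumes AAA has shape (1, 2, Z, Y, X), with A = AAA[0, 0] and B = AAA[0, 1].
    """
    A = AAA[0, 0]  # view into AAA, shape (Z, Y, X), no copy
    B = AAA[0, 1]  # view into AAA, shape (Z, Y, X), no copy
    mask = (B > A) & (B > a_threashold)  # boolean tensor, shape (Z, Y, X)
    idx = torch.nonzero(mask, as_tuple=True)  # tuple of three 1-D index tensors (z, y, x)
    values = B[idx]  # selected values, aligned with idx
    indices = torch.stack(idx, dim=1)  # shape (N, 3), each row is (z, y, x)
    return values, indices


if __name__ == "__main__":
    torch.manual_seed(0)
    AAA = torch.randn(1, 2, 4, 5, 6)  # hypothetical small sizes: Z=4, Y=5, X=6
    values, indices = select_b_above_a(AAA, a_threashold=0.5)
    for v, (z, y, x) in zip(values.tolist(), indices.tolist()):
        print(f"{v:.4f} index {z}\t{y}\t{x}")
```

Indexing B with the tuple returned by torch.nonzero(..., as_tuple=True) keeps each value aligned with its index row, so values[i] belongs to indices[i].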
I can easily do this by running three nested for loops and comparing each element, like this:

    for z in range(AAA.size(dim=2)):
        for y in range(AAA.size(dim=3)):
            for x in range(AAA.size(dim=4)):
                # keep elements where B > A and B > a_threashold
                if AAA[0, 0, z, y, x] < AAA[0, 1, z, y, x] and AAA[0, 1, z, y, x] > a_threashold:
                    print(str(AAA[0, 1, z, y, x]) + " index " + str(z) + "\t" + str(y) + "\t" + str(x))

But this is unbelievably slow (AAA holds about 500 million floating-point numbers).
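Before running anything on the full 500-million-element tensor, it can help to confirm on a small random tensor that a mask-based version selects exactly the same elements as the nested loops, and to see the speed difference. This is a sketch under assumptions: the function names loop_version and vectorized_version, the threshold 0.5, and the small shape are hypothetical, and the loop collects (z, y, x) tuples instead of printing them so the two results can be compared directly.

```python
import time

import torch


def loop_version(AAA: torch.Tensor, a_threashold: float):
    """Nested-loop reference: collect (z, y, x) where B > A and B > a_threashold."""
    hits = []
    for z in range(AAA.size(dim=2)):
        for y in range(AAA.size(dim=3)):
            for x in range(AAA.size(dim=4)):
                b = AAA[0, 1, z, y, x]
                if AAA[0, 0, z, y, x] < b and b > a_threashold:
                    hits.append((z, y, x))
    return hits


def vectorized_version(AAA: torch.Tensor, a_threashold: float):
    """Same selection using one boolean mask; nonzero lists hits in row-major order."""
    B = AAA[0, 1]
    mask = (B > AAA[0, 0]) & (B > a_threashold)
    return [tuple(row) for row in torch.nonzero(mask).tolist()]


if __name__ == "__main__":
    torch.manual_seed(0)
    AAA = torch.randn(1, 2, 8, 16, 16)  # hypothetical small size for the comparison
    t0 = time.perf_counter()
    slow = loop_version(AAA, 0.5)
    t1 = time.perf_counter()
    fast = vectorized_version(AAA, 0.5)
    t2 = time.perf_counter()
    print("same result:", slow == fast)
    print(f"loop: {t1 - t0:.4f} s, vectorized: {t2 - t1:.4f} s")
```

Both versions visit elements in z, then y, then x order, so the lists can be compared with == without sorting.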
I was thinking of using Python multiprocessing, but is there a better PyTorch way?
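PyTorch typically already spreads large elementwise CPU operations such as comparisons across several threads (torch.get_num_threads() reports how many), so Python multiprocessing is usually not needed for this kind of mask. With about 500 million values, memory is the more likely limit: torch.nonzero returns int64 indices, i.e. 24 bytes per (z, y, x) hit. A sketch, under those assumptions, that processes AAA in slabs along Z and can optionally move each slab to another device such as a GPU; chunk_z and device are hypothetical parameters of this helper, not PyTorch options, and Z is assumed to be at least 1.

```python
import torch


def select_in_chunks(AAA: torch.Tensor, a_threashold: float, chunk_z: int = 64, device=None):
    """Same selection as a full mask (B > A and B > a_threashold), done slab by slab along Z.

    Assumes AAA has shape (1, 2, Z, Y, X) with Z >= 1. Returns (values, indices),
    where indices has shape (N, 3) holding global (z, y, x) coordinates.
    """
    Z = AAA.size(2)
    all_values, all_indices = [], []
    for z0 in range(0, Z, chunk_z):
        slab = AAA[0, :, z0:z0 + chunk_z]  # view, shape (2, chunk, Y, X)
        if device is not None:
            slab = slab.to(device)  # copies only this slab, e.g. to "cuda" if available
        A, B = slab[0], slab[1]
        mask = (B > A) & (B > a_threashold)
        idx = torch.nonzero(mask, as_tuple=True)
        all_values.append(B[idx].cpu())
        zyx = torch.stack(idx, dim=1).cpu()
        zyx[:, 0] += z0  # convert slab-local z to global z
        all_indices.append(zyx)
    return torch.cat(all_values), torch.cat(all_indices)


if __name__ == "__main__":
    print("intra-op threads:", torch.get_num_threads())
    torch.manual_seed(0)
    AAA = torch.randn(1, 2, 10, 8, 8)  # hypothetical small size
    values, indices = select_in_chunks(AAA, 0.5, chunk_z=4)
    print(values.shape, indices.shape)
```

Because slabs are processed in increasing z and each slab is scanned in row-major order, the concatenated result comes out in the same order as a single full-tensor torch.nonzero call.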
Thank you so much!!!