Find the indices of given values in a 3-dimensional array in parallel

I wrote a ‘zero_mask2’ function (code pasted below) in my PyTorch code, but it is too slow, so I am looking for a better way.

First, let me explain the gist of the function.


‘all_idx_before’ (shape [batch_size, number of points, 20]): batch_size is the usual mini-batch size in deep learning. Each batch element contains many points, for example 1024, and each point has 20 values.

‘idx’ (shape [batch_size, number of points, 5]): the batch size and number of points are the same as above, but here each point has only 5 values.

‘p’: an arbitrary number, for example 7.

Toy example

Let's say a point in ‘all_idx_before’ has (9, 10, 11, 12, 6, 7, 8, 14, 2, 3, 4, 5, 18, 19, 15, 16, 17, 13, 1, 20).

And a point in ‘idx’ has (10, 6, 2, 4, 18).

In this setting, I want to find the position in ‘all_idx_before’ of each value in ‘idx’. For example, (10) in ‘idx’ is the second element of ‘all_idx_before’, (6) in ‘idx’ is the fifth element of ‘all_idx_before’…

So the expected output for this point would be (1, 4, 8, 10, 12). (‘second’ becomes ‘1’ and ‘fifth’ becomes ‘4’ because Python indexing is zero-based.)
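As a quick check of the toy example, here is a small NumPy sketch for a single point. The names `all_pt` and `idx_pt` are made up for illustration and stand for one point taken from ‘all_idx_before’ and ‘idx’; it assumes every value of `idx_pt` occurs exactly once in `all_pt`.

```python
import numpy as np

# One point from 'all_idx_before' (20 values) and one from 'idx' (5 values).
all_pt = np.array([9, 10, 11, 12, 6, 7, 8, 14, 2, 3, 4, 5,
                   18, 19, 15, 16, 17, 13, 1, 20])
idx_pt = np.array([10, 6, 2, 4, 18])
p = 7

# Compare every value of idx_pt with every value of all_pt: shape [5, 20].
matches = idx_pt[:, None] == all_pt[None, :]

# Each row contains a single True; argmax returns its (0-based) column.
positions = matches.argmax(axis=1)

# Same thresholding as in zero_mask2: True where the position is <= p.
mask = positions <= p

print(positions.tolist())  # [1, 4, 8, 10, 12]
print(mask.tolist())       # [True, True, False, False, False]
```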

I apply this logic to every batch element and every point, so I used nested for loops, but this is too slow. Is there a way to do this in parallel? My data is in NumPy arrays, but I can use tensors if that is better.

Thank you.

import numpy as np
import torch

def zero_mask2(idx, all_idx_before, p):

    size_batch = len(idx)
    size_points = len(idx[0])
    size_neighbor = len(idx[0][0])

    mask = torch.empty(size_batch, size_points, size_neighbor)

    for i in range(size_batch):
        for j in range(size_points):
            for l in range(size_neighbor):
                # position of idx[i][j][l] within all_idx_before[i][j]
                mask[i][j][l] = np.where(all_idx_before[i][j] == idx[i][j][l])[0][0]

    mask_t = mask <= p  # position <= p: True, position > p: False
    return mask_t

I would add a new dimension to those tensors and compare them:

all_idx_before = all_idx_before.view([size_batch, size_points, 20, 1])
idx = idx.view([size_batch, size_points, 1, 5])

absdiff = torch.abs(idx - all_idx_before) # [batch, points, 20, 5]

Then find indices with argmin:

output = absdiff.argmin(dim=2)
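Putting the two steps above together, a minimal sketch of a loop-free replacement for ‘zero_mask2’ might look like the following. The function name `zero_mask_vectorized` is made up for illustration. It uses `unsqueeze` instead of `view` so that non-contiguous tensors also work, and `torch.as_tensor` so that NumPy inputs are accepted; both are choices of this sketch, not part of the answer above.

```python
import torch

def zero_mask_vectorized(idx, all_idx_before, p):
    """Loop-free sketch of zero_mask2 (hypothetical name).

    idx:            [batch, points, 5]
    all_idx_before: [batch, points, 20]
    Returns a bool tensor of shape [batch, points, 5].
    """
    idx = torch.as_tensor(idx)
    all_idx_before = torch.as_tensor(all_idx_before)

    a = all_idx_before.unsqueeze(-1)  # [batch, points, 20, 1]
    b = idx.unsqueeze(-2)             # [batch, points, 1, 5]

    # Broadcasting compares every value in idx with every value
    # in all_idx_before for the same point.
    absdiff = torch.abs(b - a)        # [batch, points, 20, 5]

    # The difference is 0 exactly where the values match, so argmin
    # over the 20-axis gives the position of each idx value.
    positions = absdiff.argmin(dim=2) # [batch, points, 5]

    return positions <= p

# Toy example from the question, wrapped as batch=1, points=1.
all_idx_before = torch.tensor([[[9, 10, 11, 12, 6, 7, 8, 14, 2, 3, 4, 5,
                                 18, 19, 15, 16, 17, 13, 1, 20]]])
idx = torch.tensor([[[10, 6, 2, 4, 18]]])
print(zero_mask_vectorized(idx, all_idx_before, 7).tolist())
# [[[True, True, False, False, False]]]
```

One difference from the loop version: if a value in ‘idx’ does not occur in ‘all_idx_before’, `argmin` silently returns the position of the closest value, whereas the `np.where(...)[0][0]` lookup would raise an `IndexError`.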

Oh! That is a good way. I will try it and let you know the result. Thank you.

----- after using your method ----
Wow! It became much faster than using three nested for loops…
But it seems to consume more memory.

Thank you again.
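On the memory point: the broadcast comparison materialises an intermediate of shape [batch, points, 20, 5], which is where the extra memory goes. One possible way to bound it, not taken from the thread, is to run the same comparison over slices of the batch and concatenate the results. The function name `find_positions_chunked` and the `chunk_size` parameter are hypothetical, and a suitable chunk size would depend on the available memory.

```python
import torch

def find_positions_chunked(idx, all_idx_before, chunk_size=8):
    """Same argmin lookup, processed chunk_size batch elements at a time.

    Only a [chunk_size, points, 20, 5] intermediate is alive at once.
    Returns positions of shape [batch, points, 5].
    """
    idx = torch.as_tensor(idx)
    all_idx_before = torch.as_tensor(all_idx_before)

    out = []
    for start in range(0, idx.shape[0], chunk_size):
        a = all_idx_before[start:start + chunk_size].unsqueeze(-1)  # [c, P, 20, 1]
        b = idx[start:start + chunk_size].unsqueeze(-2)             # [c, P, 1, 5]
        out.append(torch.abs(b - a).argmin(dim=2))                  # [c, P, 5]
    return torch.cat(out, dim=0)

# Usage sketch: the mask is then computed as in zero_mask2.
# mask_t = find_positions_chunked(idx, all_idx_before) <= p
```

Smaller chunks lower the peak memory at the cost of more Python-level iterations, so this trades some of the speed-up back for a smaller footprint.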