I wrote a ‘zero_mask2’ function (code pasted below) in the middle of my PyTorch file, but it is too slow, so I am looking for a better way.
First, let me explain what the function does.
Inputs
‘all_idx_before’ (shape [batch_size, number of points, 20]): batch_size is the usual mini-batch size in deep learning. Each batch element contains many points (e.g. 1024), and each point has 20 values.
‘idx’ (shape [batch_size, number of points, 5]): the batch and point dimensions are the same as above, but each point has only 5 values. Each of these values also appears among the 20 values of the corresponding point in ‘all_idx_before’.
‘p’: an integer threshold on the position, for example 7.
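To make the shapes concrete and to have something to benchmark against, here is a minimal sketch of random inputs that match the description. It assumes, as in the toy example below, that every point of ‘all_idx_before’ is a permutation of 1..20 and that ‘idx’ picks 5 distinct values from the same point; `make_inputs` is a hypothetical helper, not part of the original code.

```python
import numpy as np

def make_inputs(batch_size, num_points, num_all=20, num_idx=5, seed=0):
    """Hypothetical helper that builds random inputs with the shapes above.

    Assumes each point of all_idx_before is a permutation of 1..num_all,
    and that idx picks num_idx distinct values from the same point.
    """
    rng = np.random.default_rng(seed)
    # argsort of random numbers gives an independent random permutation of
    # 0..num_all-1 for every (batch, point); +1 shifts it to 1..num_all.
    all_idx_before = np.argsort(rng.random((batch_size, num_points, num_all)), axis=-1) + 1
    # Choose num_idx distinct positions per point, then read the values there.
    picks = np.argsort(rng.random((batch_size, num_points, num_all)), axis=-1)[..., :num_idx]
    idx = np.take_along_axis(all_idx_before, picks, axis=-1)
    return idx, all_idx_before

idx, all_idx_before = make_inputs(batch_size=4, num_points=1024)
print(idx.shape, all_idx_before.shape)  # (4, 1024, 5) (4, 1024, 20)
```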
Toy example
Let's say a point in ‘all_idx_before’ is (9, 10, 11, 12, 6, 7, 8, 14, 2, 3, 4, 5, 18, 19, 15, 16, 17, 13, 1, 20),
and the corresponding point in ‘idx’ is (10, 6, 2, 4, 18).
I want to find the index of each value of ‘idx’ within ‘all_idx_before’. For example, 10 from ‘idx’ is the second element of ‘all_idx_before’, 6 is the fifth element, and so on.
So the positions for this point are (1, 4, 8, 10, 12) (the second element gets index 1 because Python indexing starts at 0). With p = 7, the mask returned for this point is (True, True, False, False, False).
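The lookup for a single point can be done without an inner loop by comparing every query value against every entry at once. This is a minimal NumPy sketch on the toy point above; it relies on `argmax` over a boolean array returning the first `True`, and it assumes every query value is present (if one is missing, `argmax` silently returns 0).

```python
import numpy as np

# The toy point from the example above.
point = np.array([9, 10, 11, 12, 6, 7, 8, 14, 2, 3, 4, 5,
                  18, 19, 15, 16, 17, 13, 1, 20])
query = np.array([10, 6, 2, 4, 18])

# matches[l, k] is True where point[k] == query[l]; shape (5, 20).
matches = query[:, None] == point[None, :]
# argmax on a boolean array gives the index of the first True in each row.
positions = matches.argmax(axis=1)

print(positions.tolist())         # [1, 4, 8, 10, 12]
print((positions <= 7).tolist())  # [True, True, False, False, False]
```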
I apply this to every batch element and every point using nested for loops, which is too slow. Is there a way to do this in parallel? My inputs are NumPy arrays, but I can switch to PyTorch tensors if that is better.
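One way to remove all three loops while staying in NumPy is to extend the same comparison to the batch and point dimensions with broadcasting. This is a sketch under the assumption that every value of ‘idx’ occurs in the matching point of ‘all_idx_before’; `zero_mask2_np` is a hypothetical name. It returns a NumPy bool array, which `torch.from_numpy` can convert if a tensor is needed.

```python
import numpy as np

def zero_mask2_np(idx, all_idx_before, p):
    """Vectorized NumPy sketch of zero_mask2 (no Python loops).

    idx:            int array, shape (B, N, 5)
    all_idx_before: int array, shape (B, N, 20)
    Returns a bool array of shape (B, N, 5): True where position <= p.
    Assumes every value of idx appears in the matching row of all_idx_before.
    """
    idx = np.asarray(idx)
    all_idx_before = np.asarray(all_idx_before)
    # (B, N, 5, 1) == (B, N, 1, 20) broadcasts to (B, N, 5, 20).
    matches = idx[..., :, None] == all_idx_before[..., None, :]
    # Index of the first match for each query value: shape (B, N, 5).
    positions = matches.argmax(axis=-1)
    return positions <= p
```

The intermediate boolean array has B × N × 5 × 20 entries, roughly 3.3 MB for B = 32 and N = 1024, so memory should not be a concern at these sizes.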
Thank you.
```python
import numpy as np
import torch

def zero_mask2(idx, all_idx_before, p):
    size_batch = len(idx)
    size_points = len(idx[0])
    size_neighbor = len(idx[0][0])
    mask = torch.empty(size_batch, size_points, size_neighbor)
    for i in range(size_batch):
        for j in range(size_points):
            for l in range(size_neighbor):
                # position of idx[i][j][l] inside all_idx_before[i][j]
                mask[i][j][l] = np.where(all_idx_before[i][j] == idx[i][j][l])[0][0]
    mask_t = mask <= p  # position <= p: True, position > p: False
    return mask_t
```
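Since the function already returns a torch tensor, the same broadcasting idea can be written directly in PyTorch, which also lets it run on the GPU if the inputs live there. This is a sketch, not a tested drop-in replacement: `zero_mask2_torch` is a hypothetical name, and it assumes each value of ‘idx’ appears exactly once in the matching point of ‘all_idx_before’. The boolean matches are cast to integers before `argmax` to avoid relying on `argmax` support for bool tensors.

```python
import torch

def zero_mask2_torch(idx, all_idx_before, p):
    """Vectorized PyTorch sketch of zero_mask2 (no Python loops).

    idx:            integer tensor or array, shape (B, N, 5)
    all_idx_before: integer tensor or array, shape (B, N, 20)
    Returns a bool tensor of shape (B, N, 5): True where position <= p.
    Assumes each value of idx appears exactly once in the matching row.
    """
    idx = torch.as_tensor(idx)
    all_idx_before = torch.as_tensor(all_idx_before, device=idx.device)
    # (B, N, 5, 1) == (B, N, 1, 20) broadcasts to (B, N, 5, 20).
    matches = idx.unsqueeze(-1) == all_idx_before.unsqueeze(-2)
    # Exactly one 1 per query row, so argmax returns its position.
    positions = matches.long().argmax(dim=-1)
    return positions <= p
```

If every point really is a permutation of 1..20, as in the toy example, an alternative is `torch.argsort(all_idx_before, dim=-1)`, whose entry at index v - 1 is the position of value v; the positions can then be read with `torch.gather` using `idx - 1`, which avoids the 20-wide comparison.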