Max operation over specific (masked) elements of a 4-dimensional array, done in parallel

I am writing PyTorch code. Alongside the inference code, I added some extra code for my own analysis. It works, but it is too slow, probably because of the nested for loops. So I need a parallel (vectorized), faster way of doing this.

A solution using PyTorch tensors, NumPy, or plain Python lists is fine.

I wrote a function named ‘selective_max’ to find maximum values in arrays. The catch is that I don’t want the maximum over the whole array, only over specific candidates designated by a mask array. Here is the gist of the function (the full code is shown below).

Input

x [batch_size, dim, num_points, k]: x is the original input; ‘x.permute(0,2,1,3)’ rearranges it to [batch_size, num_points, dim, k].
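As a quick shape check (a sketch with arbitrary sizes chosen for illustration), `permute(0, 2, 1, 3)` swaps the `dim` and `num_points` axes and returns a view of the same data:

```python
import torch

# Arbitrary example sizes: batch_size=2, dim=4, num_points=5, k=3
x = torch.randn(2, 4, 5, 3)          # [batch_size, dim, num_points, k]
x_perm = x.permute(0, 2, 1, 3)       # [batch_size, num_points, dim, k]
print(tuple(x_perm.shape))           # (2, 5, 4, 3)

# Element [b, n, d, c] of x_perm is element [b, d, n, c] of x
assert x_perm[1, 3, 2, 0] == x[1, 2, 3, 0]
```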

‘batch_size’ is the usual mini-batch size from deep learning. Each mini-batch contains many points, and each point is represented by a feature vector of length ‘dim’. For each feature element there are ‘k’ candidate values, which are the targets of the max operation later.

mask [batch_size, num_points, k]: this array has the same shape as the permuted ‘x’ minus the ‘dim’ axis. Each element is either ‘0’ or ‘1’, and I use it as a mask: the max is taken only over the entries whose mask value is ‘1’.
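Because `mask` differs from the permuted `x` only by the missing `dim` axis, inserting a size-1 axis at position 2 lets it broadcast over `dim` without copying data. A short PyTorch sketch, with example sizes chosen for illustration:

```python
import torch

batch_size, num_points, dim, k = 2, 5, 4, 3
x = torch.randn(batch_size, num_points, dim, k)             # permuted layout
mask = torch.randint(0, 2, (batch_size, num_points, k))     # entries are 0 or 1

# [B, N, K] -> [B, N, 1, K]; the size-1 axis broadcasts over dim
mask_b = mask.unsqueeze(2).bool()
print(tuple(mask_b.expand_as(x).shape))   # (2, 5, 4, 3)
```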

Please read the code below together with this explanation. It uses three nested for loops. Consider a specific batch and a specific point: for them, ‘x’ is a [dim, k] array and ‘mask’ is a [k] array of ‘0’s and ‘1’s. I extract the indices of the nonzero entries of the [k] mask and use them to select elements of ‘x’ one feature element at a time (the innermost loop, over ‘dim’).
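The three loops can be collapsed into one broadcast operation. Below is a minimal NumPy sketch, assuming the permuted layout `[batch_size, num_points, dim, k]`: masked-out candidates are replaced with `-inf` so they can never win the max. The function name `selective_max_np` is hypothetical, and the check at the end compares it with a loop written like the original.

```python
import numpy as np

def selective_max_np(x, mask):
    """x: [batch, num_points, dim, k] float array; mask: [batch, num_points, k] of 0/1.
    Returns [batch, num_points, dim]: the max over candidates where mask == 1."""
    keep = mask[:, :, None, :].astype(bool)   # [B, N, 1, K], broadcasts over dim
    masked = np.where(keep, x, -np.inf)       # masked-out candidates become -inf
    return masked.max(axis=-1)                # max over the k axis

# Compare against the loop version on random data
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 4, 3))
mask = rng.integers(0, 2, size=(2, 5, 3))
mask[:, :, 0] = 1                             # ensure at least one candidate per point

expected = np.zeros((2, 5, 4))
for i in range(2):
    for j in range(5):
        query = np.nonzero(mask[i][j])
        for d in range(4):
            expected[i][j][d] = np.max(x[i][j][d][query])

print(np.allclose(selective_max_np(x, mask), expected))   # True
```

One behavioral difference: if a point has no nonzero mask entry, this returns `-inf` for it, whereas `np.max` on an empty selection raises a `ValueError`.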

Toy example

Suppose we are inside the second loop, so we have a [dim, k] array for ‘x’ and a [k] array for ‘mask’. For this toy example, let k=3 and dim=4, with x = [[3,2,1],[5,6,4],[9,8,7],[12,11,10]] and mask = [0,1,1]. The output should then be [2,6,8,11], not [3,6,9,12].
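The toy example can be checked directly by replacing the masked-out candidates with `-inf` before taking the max (a small PyTorch sketch; `masked_fill` broadcasts the `[k]` mask over the `[dim, k]` array):

```python
import torch

x = torch.tensor([[3., 2., 1.],
                  [5., 6., 4.],
                  [9., 8., 7.],
                  [12., 11., 10.]])   # [dim=4, k=3]
mask = torch.tensor([0, 1, 1])       # [k]

# Positions where mask == 0 become -inf, then take the max over k
masked = x.masked_fill(mask == 0, float('-inf'))
print(masked.max(dim=1).values.tolist())   # [2.0, 6.0, 8.0, 11.0]
```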

Previous attempt

I tried multiplying x element-wise by the mask ({ mask.repeat(0,0,1,0) * x }) and then taking the max. But ‘0’ can then end up as the maximum, because all the candidates of a feature element may be negative, so this gives the wrong result.
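The failure is easy to reproduce: when every candidate is negative, the zero produced by the multiplication wins the max. A small sketch with made-up values, contrasted with filling the masked slot with `-inf`:

```python
import torch

x = torch.tensor([[-3., -2., -1.]])   # [dim=1, k=3], all candidates negative
mask = torch.tensor([0., 1., 1.])

# Multiplication: the masked slot becomes a zero, which beats -2 and -1
print((x * mask).max(dim=1).values.item())   # a zero, not the correct -1

# Filling the masked slot with -inf keeps it out of the max
print(x.masked_fill(mask == 0, float('-inf')).max(dim=1).values.item())   # -1.0
```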

Thank you in advance.

```python
import numpy as np
import torch


def selective_max2(x, mask):  # x : [batch_size, dim, num_points, k], mask : [batch_size, num_points, k]
    batch_size = x.size(0)
    dim = x.size(1)
    num_points = x.size(2)
    k = x.size(3)
    device = x.device  # return the result on the same device as the input

    x = x.permute(0, 2, 1, 3)  # : [batch, num_points, dim, k]
    # print('permuted x dimension : ', x.size())

    x = x.detach().cpu().numpy()
    mask = mask.cpu().numpy()
    output = np.zeros((batch_size, num_points, dim))

    for i in range(batch_size):
        for j in range(num_points):
            # indices of the candidates whose mask entry is nonzero
            # (np.max raises ValueError if a point has no nonzero entry)
            query = np.nonzero(mask[i][j])
            for d in range(dim):  # for each feature element, take the max over the selected candidates
                output[i][j][d] = np.max(x[i][j][d][query])

    output = torch.from_numpy(output).float().to(device=device)
    output = output.permute(0, 2, 1).contiguous()  # : [batch, dim, num_points]
    return output
```
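For comparison, here is a sketch of a loop-free replacement that stays on the input's device. It assumes the same input and output layouts as `selective_max2` (`x`: `[batch_size, dim, num_points, k]`, `mask`: `[batch_size, num_points, k]`, output `[batch_size, dim, num_points]`); the name `selective_max_vectorized` is hypothetical. Instead of permuting `x`, the mask gets a size-1 `dim` axis at position 1.

```python
import torch

def selective_max_vectorized(x, mask):
    """x: [B, dim, N, k]; mask: [B, N, k] of 0/1. Returns [B, dim, N]."""
    keep = mask.unsqueeze(1).bool()                 # [B, 1, N, k], broadcasts over dim
    masked = x.masked_fill(~keep, float('-inf'))    # excluded candidates can never win
    return masked.max(dim=-1).values                # max over k -> [B, dim, N]

# Check against a plain loop on random data
torch.manual_seed(0)
B, D, N, K = 2, 4, 5, 3
x = torch.randn(B, D, N, K)
mask = torch.randint(0, 2, (B, N, K))
mask[..., 0] = 1                                    # at least one candidate per point

ref = torch.empty(B, D, N)
for b in range(B):
    for n in range(N):
        idx = mask[b, n].nonzero(as_tuple=True)[0]
        for d in range(D):
            ref[b, d, n] = x[b, d, n, idx].max()

print(torch.equal(selective_max_vectorized(x, mask), ref))   # True
```

Unlike the loop version, this keeps the computation on the input's device and remains differentiable with respect to `x`. A point whose mask is all zeros yields `-inf` instead of raising an error, so you may want to guard against that case.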

Hihi Hoho!

Rather than use multiplication to set your masked elements to zero
(which, as you note, will be incorrect), use subtraction to set your
masked values to a large negative value.
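A minimal sketch of that subtraction idea, assuming the permuted layout `[batch_size, num_points, dim, k]` and an arbitrary large constant (`big=1e9` here, a value chosen for illustration; it must be much larger than the spread of the data). The helper name `selective_max_sub` is hypothetical.

```python
import torch

def selective_max_sub(x, mask, big=1e9):
    """x: [B, N, dim, k]; mask: [B, N, k] of 0/1. Returns [B, N, dim]."""
    # 0 where the candidate is kept, `big` where it is masked out; broadcast over dim
    penalty = (1 - mask.to(x.dtype)).unsqueeze(2) * big
    return (x - penalty).max(dim=-1).values

x = torch.tensor([[[[-3., -2., -1.]]]])    # [B=1, N=1, dim=1, k=3], all negative
mask = torch.tensor([[[0, 1, 0]]])         # only the middle candidate is allowed
print(selective_max_sub(x, mask).item())   # -2.0
```

Using `masked_fill` with `float('-inf')` instead avoids having to pick the constant at all.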

Please see this “masked argmax” post:

Best.

K. Frank


Hello, Frank. Thanks for the reply. I see your point, and it makes sense. I’ll let you know the result.