Analogue to accumarray (MATLAB) - sparse pooling


I’m wondering if there’s a way to emulate the behavior of MATLAB’s accumarray in PyTorch. More specifically, my use case is to project 2D image features onto the ground plane, as is done in MapNet. The exact specifications are as follows:

feat - image features with size (B, F, H, W) where B is batch-size and F is feature size
loc - a tensor of locations with size (B, 2, H, W) where loc[b, :, i, j] gives a 2D coordinate
output_dims - a tuple containing M, N (output dimensions)

projection - (B, F, M, N) tensor of pooled features

Feature feat[b, :, i, j] needs to be written to projection[b, :, m, n] where (m, n) = loc[b, :, i, j]. If multiple features map to the same (m, n) output location, they should be pooled according to some function (max / sum / avg).
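
To make the intended semantics concrete, here is a slow, loop-based reference sketch. The name project_reference and the mode argument are made up for illustration, and it assumes loc already holds integer-valued indices inside [0, M) x [0, N). It only covers sum and max, and is meant as a spec to check faster versions against, not as a usable implementation.

```python
import torch

def project_reference(feat, loc, output_dims, mode="sum"):
    """Loop-based reference (hypothetical helper): pool feat into a (B, F, M, N) grid.

    Assumes loc holds valid integer indices; cells that receive no feature stay zero.
    """
    B, F, H, W = feat.shape
    M, N = output_dims
    projection = torch.zeros(B, F, M, N, dtype=feat.dtype)
    written = torch.zeros(B, M, N, dtype=torch.bool)  # cells hit at least once
    for b in range(B):
        for i in range(H):
            for j in range(W):
                m, n = int(loc[b, 0, i, j]), int(loc[b, 1, i, j])
                f = feat[b, :, i, j]
                if not written[b, m, n]:
                    # First feature landing in this cell: copy it.
                    projection[b, :, m, n] = f
                elif mode == "sum":
                    projection[b, :, m, n] += f
                else:
                    # "max": element-wise maximum over the feature vector
                    projection[b, :, m, n] = torch.max(projection[b, :, m, n], f)
                written[b, m, n] = True
    return projection
```

With B = F = 1 and two pixels pointing at the same cell, sum mode adds the two feature values and max mode keeps the larger one.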

My current approach to this is as follows:

loc_rshp = loc.permute(0, 2, 3, 1).contiguous().view(-1, 2).long() # (B*H*W, 2)
full_projection = torch.zeros(B*H*W, M, N, F)
feats_rshp = feat.permute(0, 2, 3, 1).contiguous().view(-1, F) # (B*H*W, F)

# Advanced indexing
full_projection[range(B*H*W), loc_rshp[:, 0], loc_rshp[:, 1], :] = feats_rshp 

full_projection = full_projection.view(B, H*W, M, N, F).permute(0, 1, 4, 2, 3).contiguous()
projection = max_pool_ignore(full_projection, dim=1) # max_pool along dim=1, ignore zeros

It’s quite inefficient because I have to create full_projection each time, and this takes up most of the processing time in this part of the code. I would appreciate any suggestions on how this can be done more efficiently. Thanks!

Note: max_pool_ignore performs max_pooling along dimension 1, ignoring the zeros. This is the exact function:

def max_pool_ignore(input, dim=0, ignore_val=0.0, large_value=2e10):
    """
    Max pooling that ignores a specific value. If all values along the
    dimension are ignore_val, then that will be the output for that index.
    """
    # 1.0 where the entry equals ignore_val, 0.0 elsewhere
    ignored = (input == ignore_val).float()
    # Push ignored entries down to -large_value so they never win the max
    input_  = - ignored * float(large_value) + (1 - ignored) * input
    output, _ = torch.max(input_, dim=dim)
    # Positions where every entry was ignored fall back to ignore_val
    output[output == -float(large_value)] = ignore_val
    return output
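
One direction that might avoid materializing the (B*H*W, M, N, F) tensor, at least for sum pooling, is to flatten (b, m, n) into a single index and accumulate with Tensor.index_add_. The sketch below is a rough attempt under the assumption that loc holds valid integer indices; project_sum is a made-up name, and it does not cover the max case.

```python
import torch

def project_sum(feat, loc, output_dims):
    """Sum-pool feat (B, F, H, W) into (B, F, M, N) using index_add_ (sketch).

    Assumes loc (B, 2, H, W) holds indices with 0 <= m < M and 0 <= n < N.
    """
    B, F, H, W = feat.shape
    M, N = output_dims
    loc = loc.long()
    # Flatten (b, m, n) into a row index of a (B*M*N, F) buffer.
    batch = torch.arange(B, device=feat.device).view(B, 1, 1)
    flat_idx = ((batch * M + loc[:, 0]) * N + loc[:, 1]).reshape(-1)  # (B*H*W,)
    src = feat.permute(0, 2, 3, 1).reshape(-1, F)                     # (B*H*W, F)
    out = torch.zeros(B * M * N, F, dtype=feat.dtype, device=feat.device)
    # Rows of src that share an index are summed into the same output row.
    out.index_add_(0, flat_idx, src)
    return out.view(B, M, N, F).permute(0, 3, 1, 2).contiguous()
```

Average pooling could presumably be obtained by accumulating a count per cell the same way and dividing. For max pooling, newer PyTorch releases provide Tensor.scatter_reduce_ with reduce="amax", which I have not tried here.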