How to vectorize local soft-argmax function


I am trying to implement an efficient parallel and vectorized function to compute the local soft-argmax for a batch of landmarks, where each landmark is a 2D heatmap.

For example, suppose I have a tensor of shape [8, 98, 128, 128]: a batch of 8 samples, each with 98 landmarks, where each landmark is a 128x128 heatmap.

I need to compute the local soft-argmax for each heatmap. The local soft-argmax takes the maximum of a heatmap, and computes the soft-argmax locally given some window size.

For example, consider the following scenario.

  • I have a 5x5 example heatmap.
  • Take a 3x3 window around the maximum (0.90 at the center in the example)
  • Compute the soft-argmax around that maximum given the window size
# Original 5x5 heatmap                # Masked heatmap, all zeros except 3x3 around maximum
[0.01, 0.01,  0.01,  0.01, 0.01]      [0.00, 0.00, 0.00, 0.00, 0.00]
[0.01, 0.15,  0.09,  0.03, 0.01]      [0.00, 0.15, 0.09, 0.03, 0.00]
[0.01, 0.80, *0.90*, 0.65, 0.01] -->  [0.00, 0.80, 0.90, 0.65, 0.00] 
[0.01, 0.13,  0.29,  0.33, 0.01]      [0.00, 0.13, 0.29, 0.33, 0.00]
[0.01, 0.01,  0.01,  0.01, 0.01]      [0.00, 0.00, 0.00, 0.00, 0.00]

And finally, compute the traditional global soft-argmax on the masked heatmap.
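For reference, the traditional global soft-argmax can be sketched as follows (a minimal PyTorch sketch; the function name and the temperature parameter beta are my own additions, not part of the original post):

```python
import torch

def soft_argmax_2d(heatmap, beta=1.0):
    # heatmap: [H, W]. Returns the soft-argmax (row, col) as floats.
    h, w = heatmap.shape
    # Softmax over all pixels turns the heatmap into a probability map.
    probs = torch.softmax(beta * heatmap.flatten(), dim=0).view(h, w)
    rows = torch.arange(h, dtype=probs.dtype)
    cols = torch.arange(w, dtype=probs.dtype)
    # Expected coordinates under the probability map.
    row = (probs.sum(dim=1) * rows).sum()
    col = (probs.sum(dim=0) * cols).sum()
    return row, col
```

With a large beta the result approaches the hard argmax; with a peaked heatmap the coordinates land near the maximum.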

I can compute a mask that extracts the maximum of each heatmap, but I cannot find an efficient way to also select its 3x3 neighbourhood.

mask = (output==torch.amax(output, dim=(2,3), keepdim=True))

Any help would be highly appreciated. Thanks!

This might help you a little.

The soft-argmax part is still missing.

batch_size = 1
nb_channels = 3
h = 10
w = 10
window = 3

conv = torch.nn.Conv2d(nb_channels, nb_channels, window, groups=nb_channels, padding=1, bias=False)
kernel = torch.ones(3, 3).view(1, 1, 3, 3).repeat(nb_channels, 1, 1, 1)
with torch.no_grad():
    conv.weight = torch.nn.Parameter(kernel)

output = torch.rand(batch_size, nb_channels, h, w, requires_grad=True)
mask = (output==torch.amax(output, dim=(2,3), keepdim=True))

masked_output = output.clone()
masked_output[torch.logical_not(mask)] = 0

windowed_masked_output = conv(masked_output)
masked_output[windowed_masked_output>0] = output[windowed_masked_output>0]


@Matias_Vasquez Wow! Thanks so much!!! This is an incredibly simple yet elegant solution.
I hadn’t thought of using a convolution layer, which of course, makes much sense.

I had a couple of questions regarding the code, so that I fully understand it:

  • What exactly does the groups parameter do? My intuition is that it keeps the channels separate, so the outputs are not merged across channels. I would appreciate it if you could confirm that.
    conv = torch.nn.Conv2d(nb_channels, nb_channels, window, groups=nb_channels, padding=1, bias=False)

  • Could you elaborate on this line? I hadn’t seen this type of indexing before and I am struggling to understand the internals.
    masked_output[windowed_masked_output>0] = output[windowed_masked_output>0]

Lastly, I think there is a typo in these two lines: to make it work regardless of the window size, the kernel should be built from window, and padding should be set to ‘same’.

conv = torch.nn.Conv2d(nb_channels, nb_channels, window, groups=nb_channels, padding='same', bias=False)
kernel = torch.ones(window, window).view(1, 1, window, window).repeat(nb_channels, 1, 1, 1)

Again, thanks very much for the answer!!


You are correct, this is exactly what is happening.
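To see the depthwise behaviour for yourself, here is a quick check (my own example, not from the thread): with groups=nb_channels each channel is convolved with its own kernel, so zeroing one input channel zeroes exactly that output channel.

```python
import torch

nb_channels = 3
conv = torch.nn.Conv2d(nb_channels, nb_channels, 3, groups=nb_channels,
                       padding=1, bias=False)
with torch.no_grad():
    conv.weight.copy_(torch.ones_like(conv.weight))

x = torch.rand(1, nb_channels, 5, 5)
x[:, 1] = 0  # zero out input channel 1 only

y = conv(x)
# Because channels do not mix, only output channel 1 is all zeros.
assert torch.all(y[:, 1] == 0)
assert torch.all(y[:, 0] > 0) and torch.all(y[:, 2] > 0)
```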

When you do a comparison such as ==, <, > or others, you get a boolean mask (like you did in your first post).

Now when you use this mask as an index, you only access the values where the mask is True. So for your example we would get 9 values (if the maximum value is not on the edge/corner). Reading with the mask returns those values, and assigning through the mask on the left-hand side writes directly into the indexed variable (masked_output in this case), so the assignment changes that variable in place.

Since output, masked_output and windowed_masked_output all have the same shape, there is no problem using the same mask everywhere.

So in other words, we select the 9 values that are important to us (which are selected thanks to the windowed_masked_output) from the original output and we copy them into our masked_output, which has just the maximum value and 0 everywhere else.
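A tiny standalone illustration of this boolean-mask read and write (my own example, not from the thread):

```python
import torch

a = torch.zeros(2, 3)
b = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
mask = b > 3  # boolean tensor, same shape as a and b

# Right-hand side: reading with a mask returns a 1-D tensor of the selected values.
selected = b[mask]  # tensor([4., 5., 6.])

# Left-hand side: writing through a mask modifies `a` in place at those positions.
a[mask] = b[mask]
# a is now [[0., 0., 0.], [4., 5., 6.]]
```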


Now if you want to do softmax ONLY with these 9 values, then you can do something like this.

m = torch.nn.Softmax(dim=2)
masked_output[windowed_masked_output>0] = m(output[windowed_masked_output>0].view(batch_size, nb_channels, -1)).view_as(output[windowed_masked_output>0])

which will perform the softmax BEFORE we give it to our masked_output

Thanks, I forgot to change this.

Oh sorry, there is an error here. This will not work for edge/corner cases.

  • Corner → causes an error.
  • Edge → wrong answer

This should work better

m = torch.nn.Softmax(dim=2)
masked_output[windowed_masked_output>0] = output[windowed_masked_output>0]
masked_output[windowed_masked_output==0] = -float('inf')
masked_output = m(masked_output.view(batch_size, nb_channels, -1)).view_as(output)
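Putting the pieces together for the original [8, 98, 128, 128] case, the whole local soft-argmax could be sketched as one function (my own assembly of the steps above; the function name and the final expected-coordinate computation are additions not shown in the thread):

```python
import torch
import torch.nn.functional as F

def local_soft_argmax(heatmaps, window=3):
    # heatmaps: [B, C, H, W] -> coordinates [B, C, 2] as (row, col)
    b, c, h, w = heatmaps.shape

    # 1. Mask keeping only the per-heatmap maximum.
    mask = heatmaps == torch.amax(heatmaps, dim=(2, 3), keepdim=True)
    peak_only = torch.where(mask, heatmaps, torch.zeros_like(heatmaps))

    # 2. Depthwise convolution with an all-ones kernel spreads the peak
    #    over its window, marking the local neighbourhood.
    kernel = torch.ones(c, 1, window, window, device=heatmaps.device)
    neighbourhood = F.conv2d(peak_only, kernel, padding=window // 2, groups=c)

    # 3. Softmax restricted to the window: -inf everywhere outside it.
    masked = torch.where(neighbourhood > 0, heatmaps,
                         torch.full_like(heatmaps, float('-inf')))
    probs = torch.softmax(masked.view(b, c, -1), dim=2).view(b, c, h, w)

    # 4. Expected coordinates under the probability map.
    rows = torch.arange(h, dtype=probs.dtype, device=probs.device)
    cols = torch.arange(w, dtype=probs.dtype, device=probs.device)
    row = (probs.sum(dim=3) * rows).sum(dim=2)
    col = (probs.sum(dim=2) * cols).sum(dim=2)
    return torch.stack([row, col], dim=2)
```

For a heatmap with a single sharp peak away from the border, the symmetric 3x3 window makes the returned coordinates land on the peak itself.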

Thanks for pointing out that detail on the local soft-argmax! I wouldn’t have noticed otherwise, and it is important so that the values outside the window do not count towards that weighted sum.

Thanks again for everything!!!
