Hi!

I am trying to implement an efficient parallel and vectorized function to compute the local soft-argmax for a batch of landmarks, where each landmark is a 2D heatmap.

For example, suppose I have a tensor of shape `[8, 98, 128, 128]`. This corresponds to **8 batches**, each containing **98 landmarks**, where each landmark is a **128x128 heatmap**.

I need to compute the local soft-argmax for each heatmap. The **local soft-argmax** takes the maximum of a heatmap, and computes the soft-argmax locally given some window size.

For example, consider the following scenario.

- I have a 5x5 example heatmap.
- Take a window of size 3x3 around the maximum (*in the example, 0.9 at the center*).
- Compute the soft-argmax over that window.

```
# Original 5x5 heatmap              # Masked heatmap, all zeros except 3x3 around maximum
[0.01, 0.01, 0.01, 0.01, 0.01] [0.00, 0.00, 0.00, 0.00, 0.00]
[0.01, 0.15, 0.09, 0.03, 0.01] [0.00, 0.15, 0.09, 0.03, 0.00]
[0.01, 0.80, *0.90*, 0.65, 0.01] --> [0.00, 0.80, 0.90, 0.65, 0.00]
[0.01, 0.13, 0.29, 0.33, 0.01] [0.00, 0.13, 0.29, 0.33, 0.00]
[0.01, 0.01, 0.01, 0.01, 0.01] [0.00, 0.00, 0.00, 0.00, 0.00]
```

And finally, compute the traditional global soft-argmax on the masked heatmap.
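For reference, this is roughly what I mean by the global soft-argmax on a single heatmap (a minimal sketch; the temperature `beta` and the function name are my own choices):

```python
import torch

def soft_argmax_2d(heatmap: torch.Tensor, beta: float = 100.0) -> torch.Tensor:
    """Global soft-argmax over one [H, W] heatmap -> (row, col) coordinates."""
    h, w = heatmap.shape
    # Softmax over all pixels turns the heatmap into a probability map.
    probs = torch.softmax(beta * heatmap.reshape(-1), dim=0).reshape(h, w)
    rows = torch.arange(h, dtype=probs.dtype)
    cols = torch.arange(w, dtype=probs.dtype)
    # Expected coordinate = sum of coordinates weighted by probability mass.
    row = (probs.sum(dim=1) * rows).sum()
    col = (probs.sum(dim=0) * cols).sum()
    return torch.stack([row, col])
```

With a large `beta` the result approaches the hard argmax; smaller values blend in the neighbouring activations.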

I can compute a mask that selects the maximum of each heatmap, but I cannot figure out how to efficiently slice out its 3x3 neighbourhood as well.

```
mask = (output==torch.amax(output, dim=(2,3), keepdim=True))
```
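One fully vectorized approach I have been considering (a sketch, assuming a `[B, N, H, W]` input and that the window is simply clipped at the borders): compute each heatmap's argmax location, then compare coordinate grids against it to build the window mask, with no loops:

```python
import torch

def local_window_mask(output: torch.Tensor, window: int = 3) -> torch.Tensor:
    """Boolean mask, True inside a window x window box around each heatmap's
    maximum. `output` has shape [B, N, H, W]."""
    b, n, h, w = output.shape
    half = window // 2
    # Flat argmax per heatmap, then convert to (row, col) indices.
    flat_idx = output.reshape(b, n, -1).argmax(dim=-1)                    # [B, N]
    max_r = torch.div(flat_idx, w, rounding_mode="floor").view(b, n, 1, 1)
    max_c = (flat_idx % w).view(b, n, 1, 1)
    rows = torch.arange(h, device=output.device).view(1, 1, h, 1)
    cols = torch.arange(w, device=output.device).view(1, 1, 1, w)
    # True where the pixel lies within `half` of the maximum on both axes.
    return ((rows - max_r).abs() <= half) & ((cols - max_c).abs() <= half)
```

Multiplying `output` by this mask would give the masked heatmaps, after which the usual global soft-argmax could be applied, but I am not sure this is the most efficient way.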

Any help would be highly appreciated. Thanks!