# How to vectorize local soft-argmax function

Hi!

I am trying to implement an efficient parallel and vectorized function to compute the local soft-argmax for a batch of landmarks, where each landmark is a 2D heatmap.

For example, suppose I have a tensor of shape `[8,98,128,128]`. This corresponds to a batch of 8 samples, where each sample has 98 landmarks and each landmark is a 128x128 heatmap.

I need to compute the local soft-argmax for each heatmap. The local soft-argmax takes the maximum of a heatmap, and computes the soft-argmax locally given some window size.
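For reference, a plain (global) soft-argmax over each full heatmap can be sketched like this — this is just my own illustration, and the `beta` temperature and function name are my own additions:

```python
import torch

def soft_argmax_2d(heatmaps, beta=100.0):
    # heatmaps: [B, N, H, W]; returns the expected (row, col) per heatmap: [B, N, 2]
    b, n, h, w = heatmaps.shape
    weights = (beta * heatmaps).reshape(b, n, -1).softmax(dim=-1)  # [B, N, H*W]
    ys = torch.arange(h, dtype=weights.dtype).repeat_interleave(w)  # row index of each pixel
    xs = torch.arange(w, dtype=weights.dtype).repeat(h)             # col index of each pixel
    y = (weights * ys).sum(dim=-1)
    x = (weights * xs).sum(dim=-1)
    return torch.stack([y, x], dim=-1)
```

The local version should do the same thing, but with the softmax restricted to the window around the maximum.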

For example, consider the following scenario.

• I have a 5x5 example heatmap.
• Take a 3x3 window around the maximum (in the example, 0.9 at the center).
• Compute the soft-argmax over that window.
```
# Original 5x5 heatmap                # Masked heatmap, all zeros except 3x3 around the maximum
[0.01, 0.01,  0.01,  0.01, 0.01]      [0.00, 0.00, 0.00, 0.00, 0.00]
[0.01, 0.15,  0.09,  0.03, 0.01]      [0.00, 0.15, 0.09, 0.03, 0.00]
[0.01, 0.80, *0.90*, 0.65, 0.01] -->  [0.00, 0.80, 0.90, 0.65, 0.00]
[0.01, 0.13,  0.29,  0.33, 0.01]      [0.00, 0.13, 0.29, 0.33, 0.00]
[0.01, 0.01,  0.01,  0.01, 0.01]      [0.00, 0.00, 0.00, 0.00, 0.00]
```

I can compute a mask to extract the maximum of each heatmap, but I am unable to efficiently slice out its 3x3 neighbours as well.

```
mask = (output == torch.amax(output, dim=(2, 3), keepdim=True))
```

Any help would be highly appreciated. Thanks!

The soft-argmax itself is still missing here, but this should get you the windowed values:

```
import torch

batch_size = 1
nb_channels = 3
h = 10
w = 10
window = 3

# Depthwise conv with a kernel of ones: spreads each maximum into its 3x3 neighbourhood
conv = torch.nn.Conv2d(nb_channels, nb_channels, window, groups=nb_channels, padding=1, bias=False)
kernel = torch.ones(3, 3).view(1, 1, 3, 3).repeat(nb_channels, 1, 1, 1)
conv.weight = torch.nn.Parameter(kernel)

output = torch.rand(batch_size, nb_channels, h, w, requires_grad=True)

# Keep only the per-channel maximum, dilate it with the conv, then copy the
# original values back inside the dilated window
mask = (output == torch.amax(output, dim=(2, 3), keepdim=True))
masked_output = output * mask
windowed_masked_output = conv(masked_output)
masked_output[windowed_masked_output > 0] = output[windowed_masked_output > 0]

torch.set_printoptions(precision=2)
print(masked_output)
```
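To turn the windowed values into coordinates, one option (my own sketch, not part of the thread — the function name, `beta` temperature, and masking-with-`-inf` trick are all assumptions) is to softmax only inside the window and take the expectation:

```python
import torch

def local_soft_argmax(output, windowed_mask, beta=10.0):
    # output: [B, N, H, W] heatmaps
    # windowed_mask: same shape, > 0 inside the window around each maximum
    b, n, h, w = output.shape
    # -inf outside the window -> zero softmax weight there
    logits = output.masked_fill(windowed_mask <= 0, float("-inf"))
    weights = (beta * logits).reshape(b, n, -1).softmax(dim=-1)
    ys = torch.arange(h, dtype=weights.dtype).repeat_interleave(w)  # row index of each pixel
    xs = torch.arange(w, dtype=weights.dtype).repeat(h)             # col index of each pixel
    return torch.stack([(weights * ys).sum(-1), (weights * xs).sum(-1)], dim=-1)
```

Masking with `-inf` rather than `0` matters: a value of `0` would still receive softmax weight `exp(0) > 0` and bias the expectation toward the image origin.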

@Matias_Vasquez Wow! Thanks so much!!! This is an incredibly simple yet elegant solution.
I hadn’t thought of using a convolution layer, which, of course, makes a lot of sense.

I had a couple of questions regarding the code, so that I fully understand it:

• What exactly does the `groups` parameter do? My intuition is that it separates the channels, so that the outputs are not merged across channels but kept separate. I would appreciate it if you could confirm that.
`conv = torch.nn.Conv2d(nb_channels, nb_channels, window, groups=nb_channels, padding=1, bias=False)`

• Could you elaborate on this line? I hadn’t seen this type of indexing before and I’m struggling to understand the internals.
`masked_output[windowed_masked_output>0] = output[windowed_masked_output>0]`

Lastly, I think there is a typo in these two lines: to make it work regardless of the window size, the kernel should be built from `window` and the padding set to `'same'`.

```
conv = torch.nn.Conv2d(nb_channels, nb_channels, window, groups=nb_channels, padding='same', bias=False)
kernel = torch.ones(window, window).view(1, 1, window, window).repeat(nb_channels, 1, 1, 1)
```

Again, thanks very much for the answer!!


You are correct, this is exactly what is happening.
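A quick check that makes this concrete (a minimal sketch, assuming a kernel of ones as in the snippet above): with `groups=nb_channels`, output channel `i` only ever sees input channel `i`, so a spike in one channel never leaks into the others.

```python
import torch

nb_channels = 3
# groups=nb_channels -> depthwise conv: one independent 3x3 kernel per channel
conv = torch.nn.Conv2d(nb_channels, nb_channels, 3, groups=nb_channels, padding=1, bias=False)
conv.weight = torch.nn.Parameter(torch.ones(nb_channels, 1, 3, 3))

x = torch.zeros(1, nb_channels, 5, 5)
x[0, 0, 2, 2] = 1.0  # spike in channel 0 only
with torch.no_grad():
    y = conv(x)
# the spike is dilated to a 3x3 block in channel 0; channels 1 and 2 stay all zero
```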

When you do a comparison such as `==`, `<`, `>` or others, you get a boolean mask (like you did in your first post).

Now when you use this mask as an index

```
masked_output[*MASK*]
```

You are only accessing the values where the mask is `True`. So for your example we would get `9` values (if the maximum value is not on an edge/corner). One detail worth knowing: *reading* with a boolean mask returns a copy, but *assigning* through one (`masked_output[mask] = ...`) writes in place, so it modifies `masked_output` directly.

Since `output`, `masked_output` and `windowed_masked_output` all have the same size, there is no problem using the same mask everywhere.

In other words, we select the 9 values that are important to us (selected thanks to `windowed_masked_output`) from the original `output` and copy them into our `masked_output`, which otherwise holds just the maximum value and `0` everywhere else.
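A tiny illustration of that copy-through-a-mask pattern (toy 2x2 tensors, nothing from the actual code):

```python
import torch

output = torch.tensor([[1., 2.], [3., 4.]])    # original values
windowed = torch.tensor([[0., 5.], [0., 5.]])  # > 0 marks the "window"
masked = torch.zeros(2, 2)

sel = windowed > 0          # boolean mask, True where the window is
masked[sel] = output[sel]   # in-place write: selected values are copied over
# masked is now [[0., 2.], [0., 4.]]
```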

## Softmax

Now, if you want to apply the softmax ONLY to these 9 values, you can do something like this:

```
m = torch.nn.Softmax(dim=2)
```

which will perform the softmax on the windowed values before they are written into `masked_output`.

Thanks, I forgot to change this.

Oh sorry, there is an error here. This will not work for edge/corner cases.

• Corner → causes an error.

## This should work better

```
m = torch.nn.Softmax(dim=2)