How to vectorize local soft-argmax function


I am trying to implement an efficient parallel and vectorized function to compute the local soft-argmax for a batch of landmarks, where each landmark is a 2D heatmap.

For example, suppose I have a tensor of shape [8, 98, 128, 128]: a batch of 8 samples, each with 98 landmarks, where each landmark is a 128x128 heatmap.

I need to compute the local soft-argmax for each heatmap. The local soft-argmax takes the maximum of a heatmap, and computes the soft-argmax locally given some window size.

For example, consider the following scenario.

  • I have a 5x5 example heatmap.
  • Take a 3x3 window around the maximum (0.90 at the center in the example)
  • Compute the soft-argmax around that maximum given the window size
# Original 5x5 heatmap                # Masked heatmap, all zeros except 3x3 around maximum
[0.01, 0.01,  0.01,  0.01, 0.01]      [0.00, 0.00, 0.00, 0.00, 0.00]
[0.01, 0.15,  0.09,  0.03, 0.01]      [0.00, 0.15, 0.09, 0.03, 0.00]
[0.01, 0.80, *0.90*, 0.65, 0.01] -->  [0.00, 0.80, 0.90, 0.65, 0.00] 
[0.01, 0.13,  0.29,  0.33, 0.01]      [0.00, 0.13, 0.29, 0.33, 0.00]
[0.01, 0.01,  0.01,  0.01, 0.01]      [0.00, 0.00, 0.00, 0.00, 0.00]

And finally, compute the traditional global soft-argmax on the masked heatmap.
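For reference, the traditional global soft-argmax can be sketched as follows (a minimal PyTorch sketch; the function name and the temperature parameter beta are my own additions, not part of the original post):

```python
import torch

def soft_argmax_2d(heatmap, beta=1.0):
    # heatmap: [H, W]. Returns the soft-argmax (row, col) as floats.
    h, w = heatmap.shape
    # Softmax over all pixels turns the heatmap into a probability map.
    probs = torch.softmax(beta * heatmap.flatten(), dim=0).view(h, w)
    rows = torch.arange(h, dtype=probs.dtype)
    cols = torch.arange(w, dtype=probs.dtype)
    # Expected coordinates under the probability map.
    row = (probs.sum(dim=1) * rows).sum()
    col = (probs.sum(dim=0) * cols).sum()
    return row, col
```

With a large beta the result approaches the hard argmax; with a peaked heatmap the coordinates land near the maximum.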

I can compute a mask that extracts the maximum of each heatmap, but I cannot find an efficient way to also select its 3x3 neighbourhood.

mask = (output==torch.amax(output, dim=(2,3), keepdim=True))

Any help would be highly appreciated. Thanks!

This might help you a little.

The soft-argmax part is still missing.

batch_size = 1
nb_channels = 3
h = 10
w = 10
window = 3

conv = torch.nn.Conv2d(nb_channels, nb_channels, window, groups=nb_channels, padding=1, bias=False)
kernel = torch.ones(3, 3).view(1, 1, 3, 3).repeat(nb_channels, 1, 1, 1)
with torch.no_grad():
    conv.weight = torch.nn.Parameter(kernel)

output = torch.rand(batch_size, nb_channels, h, w, requires_grad=True)
mask = (output==torch.amax(output, dim=(2,3), keepdim=True))

masked_output = output.clone()
masked_output[torch.logical_not(mask)] = 0

windowed_masked_output = conv(masked_output)
masked_output[windowed_masked_output>0] = output[windowed_masked_output>0]


@Matias_Vasquez Wow! Thanks so much!!! This is an incredibly simple yet elegant solution.
I hadn’t thought of using a convolution layer, which of course, makes much sense.

I had a couple of questions regarding the code, so that I fully understand it:

  • What exactly does the groups parameter do? My intuition is that it keeps the channels separate, so the outputs are not merged across channels. I would appreciate it if you could confirm that.
    conv = torch.nn.Conv2d(nb_channels, nb_channels, window, groups=nb_channels, padding=1, bias=False)

  • Could you elaborate on this line? I hadn’t seen this type of indexing before and I am struggling to understand the internals.
    masked_output[windowed_masked_output>0] = output[windowed_masked_output>0]

Lastly, I think there is a typo in these two lines: to make it work regardless of the window size, the kernel should be built from window, and padding should be set to ‘same’.

conv = torch.nn.Conv2d(nb_channels, nb_channels, window, groups=nb_channels, padding='same', bias=False)
kernel = torch.ones(window, window).view(1, 1, window, window).repeat(nb_channels, 1, 1, 1)

Again, thanks very much for the answer!!


You are correct, this is exactly what is happening.
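To see the depthwise behaviour for yourself, here is a quick check (my own example, not from the thread): with groups=nb_channels each channel is convolved with its own kernel, so zeroing one input channel zeroes exactly that output channel.

```python
import torch

nb_channels = 3
conv = torch.nn.Conv2d(nb_channels, nb_channels, 3, groups=nb_channels,
                       padding=1, bias=False)
with torch.no_grad():
    conv.weight.copy_(torch.ones_like(conv.weight))

x = torch.rand(1, nb_channels, 5, 5)
x[:, 1] = 0  # zero out input channel 1 only

y = conv(x)
# Because channels do not mix, only output channel 1 is all zeros.
assert torch.all(y[:, 1] == 0)
assert torch.all(y[:, 0] > 0) and torch.all(y[:, 2] > 0)
```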

When you do a comparison such as ==, <, > or others, you get a boolean mask (like you did in your first post).

Now when you use this mask as an index, you only access the values where the mask is True. So for your example we would get 9 values (if the maximum value is not on the edge/corner). Reading with the mask returns those values, and assigning through the mask on the left-hand side writes directly into the indexed variable (masked_output in this case), so the assignment changes that variable in place.

Since output, masked_output and windowed_masked_output all have the same shape, there is no problem using the same mask everywhere.

So in other words, we select the 9 values that are important to us (which are selected thanks to the windowed_masked_output) from the original output and we copy them into our masked_output, which has just the maximum value and 0 everywhere else.
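A tiny standalone illustration of this boolean-mask read and write (my own example, not from the thread):

```python
import torch

a = torch.zeros(2, 3)
b = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
mask = b > 3  # boolean tensor, same shape as a and b

# Right-hand side: reading with a mask returns a 1-D tensor of the selected values.
selected = b[mask]  # tensor([4., 5., 6.])

# Left-hand side: writing through a mask modifies `a` in place at those positions.
a[mask] = b[mask]
# a is now [[0., 0., 0.], [4., 5., 6.]]
```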


Now if you want to do softmax ONLY with these 9 values, then you can do something like this.

m = torch.nn.Softmax(dim=2)
masked_output[windowed_masked_output>0] = m(output[windowed_masked_output>0].view(batch_size, nb_channels, -1)).view_as(output[windowed_masked_output>0])

which will perform the softmax BEFORE we give it to our masked_output

Thanks, I forgot to change this.

Oh sorry, there is an error here. This will not work for edge/corner cases.

  • Corner → causes an error.
  • Edge → wrong answer

This should work better

m = torch.nn.Softmax(dim=2)
masked_output[windowed_masked_output>0] = output[windowed_masked_output>0]
masked_output[windowed_masked_output==0] = -float('inf')
masked_output = m(masked_output.view(batch_size, nb_channels, -1)).view_as(output)
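Putting the pieces together for the original [8, 98, 128, 128] case, the whole local soft-argmax could be sketched as one function (my own assembly of the steps above; the function name and the final expected-coordinate computation are additions not shown in the thread):

```python
import torch
import torch.nn.functional as F

def local_soft_argmax(heatmaps, window=3):
    # heatmaps: [B, C, H, W] -> coordinates [B, C, 2] as (row, col)
    b, c, h, w = heatmaps.shape

    # 1. Mask keeping only the per-heatmap maximum.
    mask = heatmaps == torch.amax(heatmaps, dim=(2, 3), keepdim=True)
    peak_only = torch.where(mask, heatmaps, torch.zeros_like(heatmaps))

    # 2. Depthwise convolution with an all-ones kernel spreads the peak
    #    over its window, marking the local neighbourhood.
    kernel = torch.ones(c, 1, window, window, device=heatmaps.device)
    neighbourhood = F.conv2d(peak_only, kernel, padding=window // 2, groups=c)

    # 3. Softmax restricted to the window: -inf everywhere outside it.
    masked = torch.where(neighbourhood > 0, heatmaps,
                         torch.full_like(heatmaps, float('-inf')))
    probs = torch.softmax(masked.view(b, c, -1), dim=2).view(b, c, h, w)

    # 4. Expected coordinates under the probability map.
    rows = torch.arange(h, dtype=probs.dtype, device=probs.device)
    cols = torch.arange(w, dtype=probs.dtype, device=probs.device)
    row = (probs.sum(dim=3) * rows).sum(dim=2)
    col = (probs.sum(dim=2) * cols).sum(dim=2)
    return torch.stack([row, col], dim=2)
```

For a heatmap with a single sharp peak away from the border, the symmetric 3x3 window makes the returned coordinates land on the peak itself.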

Thanks for pointing out that detail on the local soft-argmax! I wouldn’t have noticed otherwise, and it is important so that the values outside the window do not count towards that weighted sum.

Thanks again for everything!!!
