How to vectorize local soft-argmax function

Hi!

I am trying to implement an efficient parallel and vectorized function to compute the local soft-argmax for a batch of landmarks, where each landmark is a 2D heatmap.

For example, suppose I have a tensor of shape [8, 98, 128, 128]. This corresponds to a batch of 8 samples, each with 98 landmarks, where each landmark is a 128x128 heatmap.

I need to compute the local soft-argmax for each heatmap. The local soft-argmax takes the maximum of a heatmap, and computes the soft-argmax locally given some window size.

For example, consider the following scenario.

  • I have a 5x5 example heatmap.
  • Take a window size of 3x3 around the maximum (in the example, the maximum is 0.9, at the center)
  • Compute the soft-argmax around that maximum given the window size
# Original 5x5 heatmap                # Masked heatmap, all zeros except 3x3 around maximum
[0.01, 0.01,  0.01,  0.01, 0.01]      [0.00, 0.00, 0.00, 0.00, 0.00]
[0.01, 0.15,  0.09,  0.03, 0.01]      [0.00, 0.15, 0.09, 0.03, 0.00]
[0.01, 0.80, *0.90*, 0.65, 0.01] -->  [0.00, 0.80, 0.90, 0.65, 0.00] 
[0.01, 0.13,  0.29,  0.33, 0.01]      [0.00, 0.13, 0.29, 0.33, 0.00]
[0.01, 0.01,  0.01,  0.01, 0.01]      [0.00, 0.00, 0.00, 0.00, 0.00]

And finally, compute the traditional global soft-argmax on the masked heatmap.
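For reference, by "traditional global soft-argmax" I mean the softmax-weighted expectation of the pixel coordinates. A minimal sketch (soft_argmax_2d is my own helper name, assuming a [B, C, H, W] tensor):

```python
import torch

def soft_argmax_2d(heatmaps, temperature=1.0):
    # heatmaps: [B, C, H, W] -> expected coordinates [B, C, 2] as (row, col)
    b, c, h, w = heatmaps.shape
    probs = torch.softmax(heatmaps.view(b, c, -1) / temperature, dim=2)
    probs = probs.view(b, c, h, w)
    # coordinate grids, broadcast against the probability map
    rows = torch.arange(h, dtype=probs.dtype).view(1, 1, h, 1)
    cols = torch.arange(w, dtype=probs.dtype).view(1, 1, 1, w)
    y = (probs * rows).sum(dim=(2, 3))
    x = (probs * cols).sum(dim=(2, 3))
    return torch.stack((y, x), dim=2)
```

With a sufficiently sharp peak the result approaches the hard argmax; the temperature controls how peaked the softmax is.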

I can compute a mask that extracts the maximum of each heatmap, but I cannot find an efficient way to also select its 3x3 neighbourhood.

mask = (output==torch.amax(output, dim=(2,3), keepdim=True))

Any help would be highly appreciated. Thanks!

This might help you a little.

The soft-argmax itself is missing.

import torch

batch_size = 1
nb_channels = 3
h = 10
w = 10
window = 3

# Depthwise convolution: groups=nb_channels gives each channel its own kernel,
# so the channels are never mixed.
conv = torch.nn.Conv2d(nb_channels, nb_channels, window, groups=nb_channels, padding=1, bias=False)
kernel = torch.ones(3, 3).view(1, 1, 3, 3).repeat(nb_channels, 1, 1, 1)
with torch.no_grad():
    conv.weight = torch.nn.Parameter(kernel)


output = torch.rand(batch_size, nb_channels, h, w, requires_grad=True)
mask = (output==torch.amax(output, dim=(2,3), keepdim=True))

masked_output = output.clone()
masked_output[torch.logical_not(mask)] = 0  # keep only the per-channel maximum

windowed_masked_output = conv(masked_output)  # non-zero inside the window around each maximum
masked_output[windowed_masked_output>0] = output[windowed_masked_output>0]  # restore the original values there

torch.set_printoptions(precision=2)
print(output)
print(masked_output)

@Matias_Vasquez Wow! Thanks so much!!! This is an incredibly simple yet elegant solution.
I hadn’t thought of using a convolution layer, which of course, makes much sense.

I had a couple of questions regarding the code, so that I fully understand it:

  • What exactly does the groups parameter do? My intuition is that it separates the channels, so that the outputs are not merged across channels but kept separate. I would appreciate it if you could confirm that.
    conv = torch.nn.Conv2d(nb_channels, nb_channels, window, groups=nb_channels, padding=1, bias=False)

  • Could you elaborate on this line? I hadn’t seen this type of indexing before, and I struggle to understand the internals.
    masked_output[windowed_masked_output>0] = output[windowed_masked_output>0]


Lastly, I think there is a typo in these two lines: to make the code work regardless of the window size, the kernel should be built from window and the padding set to ‘same’.

conv = torch.nn.Conv2d(nb_channels, nb_channels, window, groups=nb_channels, padding='same', bias=False)
kernel = torch.ones(window, window).view(1, 1, window, window).repeat(nb_channels, 1, 1, 1)
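As a quick sanity check (toy sizes, assuming an odd window), the corrected version keeps the spatial size for any window:

```python
import torch

nb_channels, window = 3, 5  # e.g. a 5x5 window instead of 3x3
conv = torch.nn.Conv2d(nb_channels, nb_channels, window,
                       groups=nb_channels, padding='same', bias=False)
kernel = torch.ones(window, window).view(1, 1, window, window).repeat(nb_channels, 1, 1, 1)
with torch.no_grad():
    conv.weight = torch.nn.Parameter(kernel)

x = torch.rand(1, nb_channels, 10, 10)
y = conv(x)
# 'same' padding keeps H and W regardless of the window size
```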

Again, thanks very much for the answer!!


You are correct, this is exactly what is happening.
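A quick way to convince yourself (toy sizes, my own example): with groups=nb_channels each channel has its own kernel, so an impulse in one input channel only affects the matching output channel.

```python
import torch

nb_channels = 3
conv = torch.nn.Conv2d(nb_channels, nb_channels, 3, groups=nb_channels,
                       padding=1, bias=False)
with torch.no_grad():
    conv.weight.fill_(1.0)  # each channel is summed over its own 3x3 window

x = torch.zeros(1, nb_channels, 5, 5)
x[0, 1, 2, 2] = 1.0  # impulse only in channel 1
y = conv(x)
# only channel 1 of the output is non-zero; channels 0 and 2 stay all zeros
```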

When you do a comparison such as ==, <, > or others, you get a boolean mask (like you did in your first post).

Now when you use this mask as an index

masked_output[*MASK*]

You are only accessing the values where the mask is True. So for your example we would get 9 values (if the maximum value is not on the edge/corner). One detail: reading with a boolean mask (x[*MASK*]) returns a copy, but assigning through it (x[*MASK*] = ...) writes directly into the indexed tensor (masked_output in this case), which is why the assignment updates that variable in place.

Since output, masked_output and windowed_masked_output all have the same size, there is no problem using the same mask everywhere.

So in other words, we select the 9 values that are important to us (which are selected thanks to the windowed_masked_output) from the original output and we copy them into our masked_output, which has just the maximum value and 0 everywhere else.
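A tiny standalone example of this kind of boolean-mask assignment (toy tensors, not from the code above):

```python
import torch

a = torch.zeros(2, 3)
b = torch.arange(6.0).view(2, 3)   # [[0, 1, 2], [3, 4, 5]]
mask = b > 2                       # boolean tensor of the same shape
a[mask] = b[mask]                  # copies the selected values of b into a
# a is now [[0, 0, 0], [3, 4, 5]]
```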

Softmax

Now if you want to apply the softmax ONLY to these 9 values, you can do something like this.

m = torch.nn.Softmax(dim=2)
masked_output[windowed_masked_output>0] = m(output[windowed_masked_output>0].view(batch_size, nb_channels, -1)).view_as(output[windowed_masked_output>0])

which performs the softmax BEFORE the values are written into masked_output.


Thanks, I forgot to change this.

Oh sorry, there is an error here. This will not work for edge/corner cases.

  • Corner → causes an error.
  • Edge → wrong answer

This should work better

m = torch.nn.Softmax(dim=2)
masked_output[windowed_masked_output>0] = output[windowed_masked_output>0]
masked_output[windowed_masked_output==0] = -float('inf')
masked_output = m(masked_output.view(batch_size, nb_channels, -1)).view_as(output)
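To finish the local soft-argmax from here, you can take the coordinate expectation over that masked softmax. A hypothetical end-to-end helper (local_soft_argmax is my own name; it reuses the -inf masking idea, via torch.where instead of in-place assignment):

```python
import torch

def local_soft_argmax(output, windowed_masked_output):
    # output, windowed_masked_output: [B, C, H, W]
    b, c, h, w = output.shape
    # -inf outside the window -> zero probability after softmax
    logits = torch.where(windowed_masked_output > 0, output,
                         torch.full_like(output, -float('inf')))
    probs = torch.softmax(logits.view(b, c, -1), dim=2).view(b, c, h, w)
    rows = torch.arange(h, dtype=probs.dtype).view(1, 1, h, 1)
    cols = torch.arange(w, dtype=probs.dtype).view(1, 1, 1, w)
    return torch.stack(((probs * rows).sum(dim=(2, 3)),
                        (probs * cols).sum(dim=(2, 3))), dim=2)  # [B, C, 2] (row, col)
```

Since only the in-window values carry probability mass, the expectation is exactly the local soft-argmax described in the first post.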

Thanks for pointing out that detail on the local soft-argmax! I wouldn’t have noticed otherwise, and it is important so that the values outside the window do not count towards that weighted sum.

Thanks again for everything!!!
