Perform max pooling across sets of feature vectors

Given k sets of features, each of size (m, n), I want to output a pooled feature of size (1, m, n) that holds, for each row, the max feature among the k sets. I’m trying to implement equation 2 from this paper: https://arxiv.org/pdf/1606.00061.pdf.

Example:
For k=3, m=3, n=4, the input is:

tensor([[[  3.2227, -12.6333,   3.4998,   3.0813],
         [  1.1984,  12.3766,  11.1678,  -2.4728],
         [  0.0000,   0.0000,   0.0000,   0.0000]],

        [[  4.4212,  -0.2568,  14.6676,   0.6086],
         [  1.1984,  12.3766,  11.1678,  -2.4728],
         [  0.0000,   0.0000,   0.0000,   0.0000]],

        [[  4.4212,  -0.2568,  14.6676,   0.6086],
         [  1.1984,  12.3766,  11.1678,  -2.4728],
         [  0.0000,   0.0000,   0.0000,   0.0000]]])

I expect the output to contain the max row features across the k sets, as follows:

tensor([[  4.4212,  -0.2568,  14.6676,   0.6086],
        [  1.1984,  12.3766,  11.1678,  -2.4728],
        [  0.0000,   0.0000,   0.0000,   0.0000]])

Explanation:
The metric I use to define a max feature is the norm of each row, computed as follows:

torch.norm(input, dim=2, keepdim=True)
tensor([[[13.8466],
         [16.8953],
         [ 0.0000]],

        [[15.3337],
         [16.8953],
         [ 0.0000]],

        [[15.3337],
         [16.8953],
         [ 0.0000]]])

For row 1, the max feature across all sets comes from set 2 (or set 3, since it has the same norm).
For row 2, the max feature across all sets comes from set 1 (or set 2 or 3, since they all have the same norm).
For row 3, the max feature across all sets comes from set 1 (or set 2 or 3, since they all have the same norm).
Together these rows form the output shown above.

I was able to use torch.argmax to get the indices from the norm tensor, but I am unable to use these max indices to select the corresponding input rows.

Solved it with the help of the method from the thread "How to select index over two dimension?" and by grouping my features across the sets (group 1: all row 1s among the k sets, group 2: all row 2s among the k sets). I then apply the norm and argmax to these groups to get the desired output.
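For anyone landing here later, the grouping approach described above can be sketched roughly as follows. This is a minimal illustration under my own assumptions, not necessarily the exact code from the linked thread: the function name max_norm_pool is made up for this example, and the input is assumed to be a single tensor of shape (k, m, n).

```python
import torch


def max_norm_pool(features: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: for each row, keep the vector with the largest L2 norm.

    features has shape (k, m, n); the result has shape (m, n).
    """
    k, m, n = features.shape
    # Group by row: grouped[i] holds row i from each of the k sets.
    grouped = features.permute(1, 0, 2)      # (m, k, n)
    # L2 norm of every candidate row vector.
    norms = grouped.norm(dim=2)              # (m, k)
    # Index of the set with the largest norm, per row.
    best = norms.argmax(dim=1)               # (m,)
    # Advanced indexing: pick grouped[i, best[i]] for every row i.
    rows = torch.arange(m, device=features.device)
    return grouped[rows, best]               # (m, n)


x = torch.tensor([[[3.2227, -12.6333, 3.4998, 3.0813],
                   [1.1984, 12.3766, 11.1678, -2.4728],
                   [0.0000, 0.0000, 0.0000, 0.0000]],
                  [[4.4212, -0.2568, 14.6676, 0.6086],
                   [1.1984, 12.3766, 11.1678, -2.4728],
                   [0.0000, 0.0000, 0.0000, 0.0000]],
                  [[4.4212, -0.2568, 14.6676, 0.6086],
                   [1.1984, 12.3766, 11.1678, -2.4728],
                   [0.0000, 0.0000, 0.0000, 0.0000]]])
pooled = max_norm_pool(x)
print(pooled)
```

If the (1, m, n) shape is needed, pooled.unsqueeze(0) adds the leading dimension. When several sets tie on the norm, argmax picks one of them; in the example above the tied rows are identical, so the result is the same either way.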
