Given k sets of features, each of size (m, n), I want to output a pooled feature of size (1, m, n) that holds, for each row, the max feature across the k sets. I’m trying to implement equation 2 from this paper: https://arxiv.org/pdf/1606.00061.pdf.

**Example:**

For k=3, m=3, n=4 input:

```
tensor([[[ 3.2227, -12.6333, 3.4998, 3.0813],
         [ 1.1984, 12.3766, 11.1678, -2.4728],
         [ 0.0000, 0.0000, 0.0000, 0.0000]],
        [[ 4.4212, -0.2568, 14.6676, 0.6086],
         [ 1.1984, 12.3766, 11.1678, -2.4728],
         [ 0.0000, 0.0000, 0.0000, 0.0000]],
        [[ 4.4212, -0.2568, 14.6676, 0.6086],
         [ 1.1984, 12.3766, 11.1678, -2.4728],
         [ 0.0000, 0.0000, 0.0000, 0.0000]]])
```

I expect the output to contain the max row features across the k sets, as follows:

```
tensor([[ 4.4212, -0.2568, 14.6676, 0.6086],
        [ 1.1984, 12.3766, 11.1678, -2.4728],
        [ 0.0000, 0.0000, 0.0000, 0.0000]])
```

**Explanation:**

The metric I use to define the max feature is the norm of each row, computed as follows:

```
torch.norm(input, dim=2, keepdim=True)
tensor([[[13.8466],
         [16.8953],
         [ 0.0000]],
        [[15.3337],
         [16.8953],
         [ 0.0000]],
        [[15.3337],
         [16.8953],
         [ 0.0000]]])
```
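
For reference, here is a minimal sketch of how I get from the norms to the per-row indices on the example input. The variable names `x`, `norms`, and `best_set` are just placeholders for this post, and I drop `keepdim` here so the shapes are easier to read:

```python
import torch

# The example input from above: k=3 sets, each of size (m=3, n=4).
x = torch.tensor([
    [[3.2227, -12.6333, 3.4998, 3.0813],
     [1.1984, 12.3766, 11.1678, -2.4728],
     [0.0000, 0.0000, 0.0000, 0.0000]],
    [[4.4212, -0.2568, 14.6676, 0.6086],
     [1.1984, 12.3766, 11.1678, -2.4728],
     [0.0000, 0.0000, 0.0000, 0.0000]],
    [[4.4212, -0.2568, 14.6676, 0.6086],
     [1.1984, 12.3766, 11.1678, -2.4728],
     [0.0000, 0.0000, 0.0000, 0.0000]],
])

# One L2 norm per row per set: shape (k, m).
norms = torch.norm(x, dim=2)

# For each row, the index of the set with the largest norm: shape (m,).
best_set = torch.argmax(norms, dim=0)

print(norms.shape, best_set.shape)
```

Note that `best_set` uses 0-based indices, so "set 2" in the explanation below corresponds to index 1.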

For row 1, the max feature across all sets comes from set 2 (or set 3, since it has the same norm).

For row 2, the max feature across all sets comes from set 1 (or set 2 or 3, since they all have the same norm).

For row 3, the max feature across all sets comes from set 1 (or set 2 or 3, since they all have the same norm).

This forms the output as shown above.

I was able to use torch.argmax to get the indices from the norm tensor, but I am unable to use these max indices to select the corresponding input rows.
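
One possible way to do the selection, as a sketch rather than a confirmed solution, is to expand the argmax indices along the feature dimension and pass them to `torch.gather` along dim 0. The function name `max_norm_pool` and the random test tensor are hypothetical and only for illustration; I have not checked this against the paper's implementation.

```python
import torch

def max_norm_pool(features: torch.Tensor) -> torch.Tensor:
    """Pool (k, m, n) features into (1, m, n), keeping for each row the
    version with the largest L2 norm across the k sets."""
    n = features.size(2)
    norms = torch.norm(features, dim=2, keepdim=True)  # (k, m, 1)
    idx = torch.argmax(norms, dim=0, keepdim=True)     # (1, m, 1)
    idx = idx.expand(-1, -1, n)                        # (1, m, n)
    # For dim=0, gather computes out[0, j, c] = features[idx[0, j, c], j, c],
    # so every column of row j is read from the same winning set.
    return torch.gather(features, 0, idx)              # (1, m, n)

feats = torch.randn(3, 3, 4)  # hypothetical input: k=3, m=3, n=4
pooled = max_norm_pool(feats)
print(pooled.shape)  # torch.Size([1, 3, 4])
```

Calling `pooled.squeeze(0)` would give the (m, n) layout shown in the expected output above. When several sets tie on the norm, which one is picked depends on `torch.argmax`; in the example that does not matter because the tied rows are identical.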