Reduce dimension by selecting along one domain inside another domain

I have a tensor of order 5 and want to reduce it to order 4 by selectively picking along the third domain based on the position in the second domain.

Let’s say my original tensor’s shape is (1, 16, 16, 8, 8) and I want to get (1, 16, 8, 8). So far, I am doing this iteratively (the snippet uses smaller sizes):

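```python
import torch

# Small stand-in for the real (1, 16, 16, 8, 8) tensor
original_tensor = torch.randint(10, (1, 4, 4, 2, 2))
output_tensor = torch.empty((1, 4, 2, 2))
# Keep only the entries where the dim-1 index equals the dim-2 index
for i in range(output_tensor.shape[1]):
    output_tensor[:, i, :, :] = original_tensor[:, i, i, :, :]
```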

Putting the outputs here would make the post very long, but you can run the snippet on its own and print both original_tensor and output_tensor. For the simpler case of reducing a second-order tensor to a first-order tensor, consider this example, where the first tensor is the input and the second is the output:

```
tensor([[5, 8, 2, 5],
        [1, 4, 5, 8],
        [1, 5, 6, 0],
        [3, 2, 7, 2]])

tensor([5., 4., 6., 2.])
```

As you can imagine, doing this iteratively is way too slow, since it is part of a layer and I need to execute it a lot. However, I know the indices I want to keep beforehand (they are constant), so I could precalculate some sort of mask if that simplifies things.
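A minimal sketch of the precomputed-mask idea, assuming the kept positions are exactly those where the dim-1 index equals the dim-2 index (as in the loop above). A 2-D boolean mask indexes dims 1 and 2 together and returns the selected entries in row-major order of the True positions. Boolean indexing has to locate the True entries on every call, so precomputed integer indices are likely the cheaper variant; the mask is shown because it matches the idea described here.

```python
import torch

# Real-size example: (batch, 16, 16, 8, 8) -> (batch, 16, 8, 8)
original_tensor = torch.randint(10, (1, 16, 16, 8, 8))

# Precompute once (e.g. when the layer is built): True only on the
# positions where the dim-1 index equals the dim-2 index.
mask = torch.eye(16).bool()

# The (16, 16) mask covers dims 1 and 2; the 16 selected (i, i) entries
# end up in a single new dim at position 1.
output_tensor = original_tensor[:, mask]
print(output_tensor.shape)  # torch.Size([1, 16, 8, 8])
```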

Please keep in mind the added complexity due to the order and size of the real example from the first snippet. A solution that only works for the second-order example would not help.

I don’t quite understand this use case, but based on your loop approach, you could use arange to index the tensor via:

```python
original_tensor[:, torch.arange(original_tensor.size(1)), torch.arange(original_tensor.size(2))]
```

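which would give the same values as output_tensor (the result keeps the integer dtype of original_tensor, while output_tensor is float because torch.empty uses the default float dtype).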
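A sketch checking the arange indexing against the original loop at the real size, assuming dims 1 and 2 have the same length. It also shows torch.diagonal as an alternative (not part of the reply above): it returns a view instead of a copy, but appends the diagonal as the last dim, so movedim puts it back at position 1. The index tensor can be built once and reused, e.g. stored with register_buffer inside the layer.

```python
import torch

original_tensor = torch.randint(10, (1, 16, 16, 8, 8))
n = original_tensor.size(1)  # assumes size(1) == size(2)

# Reference: the loop from the question (same dtype so torch.equal applies).
loop_out = torch.empty((1, n, 8, 8), dtype=original_tensor.dtype)
for i in range(n):
    loop_out[:, i] = original_tensor[:, i, i]

# Vectorized: paired integer indices on dims 1 and 2.
idx = torch.arange(n)
indexed_out = original_tensor[:, idx, idx]

# Alternative view: diagonal over dims 1 and 2 has shape (1, 8, 8, 16),
# so move the last dim back to position 1.
diag_out = torch.diagonal(original_tensor, dim1=1, dim2=2).movedim(-1, 1)

print(indexed_out.shape)  # torch.Size([1, 16, 8, 8])
print(torch.equal(loop_out, indexed_out), torch.equal(loop_out, diag_out))  # True True
```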


Perfect! Thank you!

Am I correctly assuming that this works because the indexing pairs up the indices given for the individual domains? Intuitively it feels weird to pass a list of all available indices, as I would have expected this to return the complete tensor.

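Yes. When several dimensions are indexed with index tensors, the indices are paired elementwise rather than combined, so with torch.arange both dimensions are indexed with [0, 0], [1, 1], [2, 2], ..., which is equivalent to your loop. Slices (:) are what select every combination and return the complete tensor.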
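A small sketch on the second-order example from the question, contrasting slices, paired index tensors, and broadcast index tensors. Broadcasting a column vector of indices against a row vector is one way to get every combination back from index tensors.

```python
import torch

x = torch.tensor([[5, 8, 2, 5],
                  [1, 4, 5, 8],
                  [1, 5, 6, 0],
                  [3, 2, 7, 2]])
idx = torch.arange(4)

# Slices select every combination of rows and columns: the full matrix.
full = x[:, :]

# Two index tensors are paired elementwise: (0, 0), (1, 1), (2, 2), (3, 3).
paired = x[idx, idx]
print(paired)  # tensor([5, 4, 6, 2])

# Broadcasting a (4, 1) index against a (1, 4) index yields all 16 pairs.
combos = x[idx[:, None], idx[None, :]]
print(torch.equal(combos, full))  # True
```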