I have a tensor A with shape torch.Size([10, 20, 30]).
I also have a tensor B with shape torch.Size([1, 2, 3, 5]).
I want to gather values from tensor A using the indices stored in tensor B. The gathering should happen along the last dimension of A: every value in B is an index into the last dimension of A, so B.max() < 30.
The resulting tensor should have the shape torch.Size([1, 2, 3, 5, 10, 20]).
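To pin down the intended semantics, here is a slow, loop-based reference. It assumes the desired result is out[a, b, c, d, i, j] = A[i, j, B[a, b, c, d]], which matches the shapes above; the helper name reference_gather and the random test data are hypothetical, for illustration only.

```python
import itertools

import torch


def reference_gather(A, B):
    # Hypothetical helper: out[a, b, c, d, i, j] = A[i, j, B[a, b, c, d]]
    # Output shape is B.shape followed by all but the last dim of A.
    out = A.new_empty(tuple(B.shape) + tuple(A.shape[:-1]))
    # Visit every index position of B, e.g. (0, 1, 2, 4)
    for idx in itertools.product(*(range(n) for n in B.shape)):
        k = B[idx].item()      # index into A's last dimension
        out[idx] = A[..., k]   # copy the [10, 20] slice A[:, :, k]
    return out


A = torch.randn(10, 20, 30)
B = torch.randint(0, 30, (1, 2, 3, 5))  # every index < 30
ref = reference_gather(A, B)
print(ref.shape)  # torch.Size([1, 2, 3, 5, 10, 20])
```

This is only meant as a specification to check other implementations against; the Python loop is far too slow for the real sizes.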
Currently I can get this result by repeating both tensors, so that A becomes torch.Size([10, 20, 1, 2, 3, 30]) and B becomes torch.Size([10, 20, 1, 2, 3, 5]), and then calling torch.gather() on them. In reality the dimensions are much larger, and this uses too much memory. Is there a way to achieve the same result without repeating the tensors?
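For reference, a minimal reconstruction of the repeat-then-gather approach described above might look like the following. The exact repeat calls in my real code are not shown here, so this is a sketch under that assumption, using random test data; the final permute is what turns the gather output [10, 20, 1, 2, 3, 5] into the target layout.

```python
import torch

A = torch.randn(10, 20, 30)
B = torch.randint(0, 30, (1, 2, 3, 5))

# Insert singleton dims for B's leading axes, then tile:
# [10, 20, 1, 1, 1, 30] -> [10, 20, 1, 2, 3, 30]
A_rep = A[:, :, None, None, None, :].repeat(1, 1, *B.shape[:3], 1)

# Insert singleton dims for A's leading axes, then tile:
# [1, 1, 1, 2, 3, 5] -> [10, 20, 1, 2, 3, 5]
B_rep = B[None, None].repeat(*A.shape[:2], 1, 1, 1, 1)

# Gather along the last dim: out[i, j, a, b, c, d] = A[i, j, B[a, b, c, d]]
out = torch.gather(A_rep, -1, B_rep)  # [10, 20, 1, 2, 3, 5]

# Move A's leading dims to the end -> [1, 2, 3, 5, 10, 20]
out = out.permute(2, 3, 4, 5, 0, 1)
print(out.shape)  # torch.Size([1, 2, 3, 5, 10, 20])
```

Both repeated tensors are full copies: A_rep alone holds 10 * 20 * 1 * 2 * 3 * 30 = 36,000 elements versus 6,000 in A, and that blow-up grows with the real sizes.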