I have a tensor `A` with shape `torch.Size([10, 20, 30])`.

I also have a tensor `B` with shape `torch.Size([1, 2, 3, 5])`.

I want to gather values from `A` using the indices stored in `B`. The gathering should happen along the last dimension of `A`: each of the 5 values in the last dimension of `B` is an index into the last dimension of `A`, so `B.max() < 30`.

The resulting tensor should have shape `torch.Size([1, 2, 3, 5, 10, 20])`.
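To pin down the intended semantics, here is a slow, loop-based sketch of the result as I read the description above, assuming `out[i, j, k, l, m, n] = A[m, n, B[i, j, k, l]]`. The function name `reference_gather` is hypothetical, and this is only meant as a correctness reference on small inputs, not as a solution.

```python
import itertools

import torch


def reference_gather(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical loop-based reference: out[i, j, k, l] = A[:, :, B[i, j, k, l]]."""
    # Output shape: all of B's dims, followed by all but the last dim of A
    out = A.new_empty(tuple(B.shape) + tuple(A.shape[:-1]))
    # Visit every position in B; each entry selects one (10, 20) slice of A
    for idx in itertools.product(*(range(s) for s in B.shape)):
        out[idx] = A[..., int(B[idx])]
    return out


torch.manual_seed(0)
A = torch.randn(10, 20, 30)
B = torch.randint(0, 30, (1, 2, 3, 5))  # every index is < 30
out = reference_gather(A, B)
print(out.shape)  # torch.Size([1, 2, 3, 5, 10, 20])
```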

Currently I can get this result by repeating both tensors so that `A` becomes `torch.Size([10, 20, 1, 2, 3, 30])` and `B` becomes `torch.Size([10, 20, 1, 2, 3, 5])`, and then calling `torch.gather()` on them. In reality the dimensions are much larger, and the repeated copies take too much memory. Is there a way to achieve the same result without repeating the tensors?
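For reference, this is a sketch of the repeat-then-gather approach as I understand it from the description; the exact `repeat` calls and the final `permute` (needed because `gather` returns shape `[10, 20, 1, 2, 3, 5]`) are my assumptions. The function name `gather_by_repeating` is hypothetical, and the code is hard-coded for a 3-D `A` and a 4-D `B`.

```python
import torch


def gather_by_repeating(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical reconstruction of the memory-hungry repeat + gather approach."""
    # A: (10, 20, 30) -> (10, 20, 1, 1, 1, 30) -> copied to (10, 20, 1, 2, 3, 30)
    A_rep = A[:, :, None, None, None, :].repeat(1, 1, *B.shape[:-1], 1)
    # B: (1, 2, 3, 5) -> (1, 1, 1, 2, 3, 5) -> copied to (10, 20, 1, 2, 3, 5)
    B_rep = B[None, None].repeat(*A.shape[:-1], 1, 1, 1, 1)
    # gathered[m, n, i, j, k, l] = A[m, n, B[i, j, k, l]]
    gathered = torch.gather(A_rep, -1, B_rep)
    # Move the (10, 20) dims to the end: (1, 2, 3, 5, 10, 20)
    return gathered.permute(2, 3, 4, 5, 0, 1)


torch.manual_seed(0)
A = torch.randn(10, 20, 30)
B = torch.randint(0, 30, (1, 2, 3, 5))
out = gather_by_repeating(A, B)
print(out.shape)  # torch.Size([1, 2, 3, 5, 10, 20])
```

The memory cost comes from `A_rep` and `B_rep`, which materialize full copies whose size grows with the product of the dimensions of both tensors.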