torch.gather() without repeating the tensors

I have a tensor A with shape: torch.Size([10, 20, 30])

I also have a tensor B with shape: torch.Size([1, 2, 3, 5])

I want to gather values from tensor A using the indices in tensor B. The gathering should happen along the last dimension of A: every entry of B (whose last dimension has size 5) is an index into the last dimension of A, so B.max() < 30.

The resulting tensor should have the shape torch.Size([1, 2, 3, 5, 10, 20]).
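Stated element-wise, the requirement above is result[i, j, k, l, m, n] = A[m, n, B[i, j, k, l]]. The following is a minimal sketch of an alternative to gather (not taken from the thread). It checks that definition against plain advanced indexing: moving A's last dimension to the front with permute() and indexing that dimension with B gives the requested shape directly, without repeating either tensor. The random inputs are placeholders for the real data.

```python
import itertools

import torch

A = torch.rand(10, 20, 30)
B = torch.randint(0, 30, (1, 2, 3, 5))  # every entry is an index into A's last dim

# Element-wise definition (slow loop, for clarity only):
#   result[i, j, k, l, m, n] = A[m, n, B[i, j, k, l]]
expected = torch.empty(*B.shape, *A.shape[:2], dtype=A.dtype)
for idx in itertools.product(*(range(s) for s in B.shape)):
    expected[idx] = A[:, :, B[idx].item()]

# Vectorised: move A's last dim to the front, then index that dim with B.
# A single LongTensor index on dim 0 yields shape B.shape + (10, 20).
result = A.permute(2, 0, 1)[B]

print(result.shape)                   # torch.Size([1, 2, 3, 5, 10, 20])
print(torch.equal(result, expected))  # True
```

Indexing still allocates the full output tensor, but nothing the size of the repeated A is created.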

Currently I get this result (with the dimensions ordered as [10, 20, 1, 2, 3, 5]) by repeating both tensors, so that A becomes torch.Size([10, 20, 1, 2, 3, 30]) and B becomes torch.Size([10, 20, 1, 2, 3, 5]), and then calling torch.gather() on them. My real dimensions are much larger, and the repeated copies take too much memory. Is there a way to achieve the same result without repeating the tensors?

Could you post your current code so we could use it as a reference for a potentially better implementation?

Sure, here it is:

```python
import torch

A = torch.rand(10, 20, 30)
B = torch.rand(1, 2, 3, 5)
B = (B * 30).type(torch.long)  # random indices in [0, 29]

A = A[:, :, None, None, None, :].repeat(1, 1, *B.shape[:3], 1)  # 10, 20, 1, 2, 3, 30
B = B[None, None, :, :, :, :].repeat(*A.shape[:2], 1, 1, 1, 1)  # 10, 20, 1, 2, 3, 5

# 10, 20, 1, 2, 3, 5; .permute(2, 3, 4, 5, 0, 1) reorders it to 1, 2, 3, 5, 10, 20
result = torch.gather(A, -1, B)
```
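For comparison with the code above, here is a sketch (not from the thread) that avoids materialising either repeated tensor. torch.index_select() only accepts a 1-D index, so B is flattened, the values are picked along A's last dimension, and that dimension is then reshaped back into B's shape. The random inputs are placeholders. The repeat-and-gather version is recomputed at the end only to check that both give the same (10, 20, 1, 2, 3, 5) tensor.

```python
import torch

A = torch.rand(10, 20, 30)
B = torch.randint(0, 30, (1, 2, 3, 5))  # indices into A's last dim

# index_select needs a 1-D index, so flatten B (row-major), pick along
# A's last dim, then split that dim back into B's shape.
picked = A.index_select(2, B.reshape(-1))     # 10, 20, B.numel()
out = picked.reshape(*A.shape[:2], *B.shape)  # 10, 20, 1, 2, 3, 5

# Reference: the repeat-based gather from the code above (memory-hungry).
A_rep = A[:, :, None, None, None, :].repeat(1, 1, *B.shape[:3], 1)
B_rep = B[None, None].repeat(*A.shape[:2], 1, 1, 1, 1)
ref = torch.gather(A_rep, -1, B_rep)

print(torch.equal(out, ref))  # True
```

Since `picked` is contiguous, the reshape is a view, so the only large allocation is the output itself.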

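I’m also looking for a good solution here. One workaround, however, is to use expand() instead of repeat(): expand() returns a view of the original tensor rather than a copy, so the enlarged A and B take no extra memory. The output of gather() is still allocated in full.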
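A minimal sketch of that workaround, assuming the same shapes as in the code above: only repeat() is swapped for expand(), with -1 meaning "keep this dimension's size". The expanded views share storage with the originals and use stride 0 along the broadcast dimensions. Whether gather() copies its inputs internally is an implementation detail not checked here.

```python
import torch

A = torch.rand(10, 20, 30)
B = torch.randint(0, 30, (1, 2, 3, 5))

# expand() returns views; the new dimensions get stride 0, so no data is copied.
A_exp = A[:, :, None, None, None, :].expand(-1, -1, *B.shape[:3], -1)  # 10, 20, 1, 2, 3, 30
B_exp = B[None, None].expand(*A.shape[:2], -1, -1, -1, -1)            # 10, 20, 1, 2, 3, 5

result = torch.gather(A_exp, -1, B_exp)  # 10, 20, 1, 2, 3, 5

print(A_exp.data_ptr() == A.data_ptr())  # True: same storage as A
print(A_exp.stride())                    # broadcast dims have stride 0
```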