Dear all,
I have a question about the memory consumption of tf.gather and the equivalent PyTorch implementation.
I have a tensor A of size [N x 32] and an index array B of size [N x 3]; every row of B contains the row indices I want to select from A, so the final output I want is an array of size [N x 3 x 32].
For example:
A = torch.rand((6,32))
B = torch.tensor([[1,2,3],[2,4,5],[2,6,5],[1,5,3],[1,2,5],[4,5,6]]) - 1
# In TensorFlow, I can just do tf.gather(A, B, axis=0), which gives me an output of size [6, 3, 32].
# In PyTorch, the way I came up with is:
A_enlarged = torch.stack([A] * 6, dim=0)              # [6, 6, 32]
output = A_enlarged[torch.arange(6).unsqueeze(1), B]  # [6, 3, 32]
# This should give me the same result.
But in my implementation N is quite large, so A_enlarged consumes a lot of memory. Is there a tf.gather-equivalent function in PyTorch? Or does tf.gather consume the same amount of memory as my PyTorch implementation? I look forward to your reply. Thank you very much.
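A minimal sketch of what I have in mind, assuming PyTorch's plain advanced indexing A[B] behaves like tf.gather(A, B, axis=0) and avoids the enlarged copy (please correct me if I'm wrong about the memory behaviour):

```python
import torch

torch.manual_seed(0)

N = 6
A = torch.rand((N, 32))                                  # [N, 32]
B = torch.tensor([[1, 2, 3], [2, 4, 5], [2, 6, 5],
                  [1, 5, 3], [1, 2, 5], [4, 5, 6]]) - 1  # [N, 3]

# Advanced indexing gathers rows of A along dim 0, so only the
# [N, 3, 32] output is allocated, not an [N, N, 32] enlarged copy.
output = A[B]
print(output.shape)  # torch.Size([6, 3, 32])

# torch.index_select does the same but needs a 1-D index,
# so flatten B and reshape the result back:
output2 = torch.index_select(A, 0, B.reshape(-1)).reshape(*B.shape, A.shape[1])
print(torch.equal(output, output2))  # True
```

If both of these really allocate only the output tensor, that would solve my problem for large N.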