Equivalent of tf.gather in PyTorch with the same memory consumption?

Dear all,

I have a question about the memory consumption of tf.gather and the equivalent PyTorch implementation.
I have a tensor A of size [N x 32] and a tensor B of size [N x 3], where every row of B contains the row indices I want to select from A, so the final output should be a tensor of size [N x 3 x 32].
For example:

import torch

A = torch.rand((6, 32))
B = torch.tensor([[1, 2, 3], [2, 4, 5], [2, 6, 5], [1, 5, 3], [1, 2, 5], [4, 5, 6]]) - 1
# in tensorflow, I can just do tf.gather(A, B, axis=0), which gives an output of size [6x3x32]
# in pytorch, the way I came up with is
A_enlarged = torch.stack([A] * 6, dim=0)
output = A[:, B]
# this should give me the same result

But in my case N is quite large, so A_enlarged will consume a lot of memory. Is there a tf.gather-equivalent function in PyTorch? Or does tf.gather consume the same amount of memory as my PyTorch implementation? I am looking forward to your reply. Thank you very much.
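To make the memory concern concrete, here is a small sketch of the enlarged-copy route I had in mind (the variable names are just for illustration; for my real N the stacked tensor is the part I worry about):

import torch

N = 6
A = torch.rand(N, 32)
B = torch.tensor([[1, 2, 3], [2, 4, 5], [2, 6, 5],
                  [1, 5, 3], [1, 2, 5], [4, 5, 6]]) - 1

# Enlarged-copy route: materialises N copies of A ([N, N, 32]) before selecting.
A_enlarged = torch.stack([A] * N, dim=0)                # [N, N, 32]
output = A_enlarged[torch.arange(N).unsqueeze(1), B]    # [N, 3, 32]
print(output.shape)                                     # torch.Size([6, 3, 32])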

Hi,

Can you share what output you expect to get on a small Tensor here?

Sure. Just to keep the tensor small, I made dim 1 of A be 9 instead of 32. The new code is below.
There was an error in the previous code; I have corrected it here.

A = torch.rand(6,9)
# A: tensor([[0.4918, 0.9813, 0.6889, 0.3126, 0.3797, 0.5551, 0.8716, 0.2215, 0.5576],
#        [0.7467, 0.4867, 0.6318, 0.4681, 0.1874, 0.9327, 0.6479, 0.0885, 0.9818],
#        [0.0431, 0.3718, 0.4447, 0.6178, 0.8779, 0.5084, 0.4198, 0.7450, 0.3994],
#        [0.5677, 0.7410, 0.2781, 0.9504, 0.1205, 0.6734, 0.2124, 0.4071, 0.7975],
#        [0.0432, 0.1230, 0.7874, 0.3125, 0.2123, 0.8771, 0.1967, 0.6190, 0.6263],
#        [0.0129, 0.1492, 0.0374, 0.2873, 0.4654, 0.2602, 0.7498, 0.4395, 0.7209]])
B = torch.tensor([[1, 2, 3], [0, 3, 4], [4, 5, 2], [1, 4, 5]])

#expected output with output = A[B,:]
output = A[B,:]
# output = tensor([[[0.7467, 0.4867, 0.6318, 0.4681, 0.1874, 0.9327, 0.6479, 0.0885, 0.9818],
#                   [0.0431, 0.3718, 0.4447, 0.6178, 0.8779, 0.5084, 0.4198, 0.7450, 0.3994],
#                   [0.5677, 0.7410, 0.2781, 0.9504, 0.1205, 0.6734, 0.2124, 0.4071, 0.7975]],
#
#                  [[0.4918, 0.9813, 0.6889, 0.3126, 0.3797, 0.5551, 0.8716, 0.2215, 0.5576],
#                   [0.5677, 0.7410, 0.2781, 0.9504, 0.1205, 0.6734, 0.2124, 0.4071, 0.7975],
#                   [0.0432, 0.1230, 0.7874, 0.3125, 0.2123, 0.8771, 0.1967, 0.6190, 0.6263]],
#
#                  [[0.0432, 0.1230, 0.7874, 0.3125, 0.2123, 0.8771, 0.1967, 0.6190, 0.6263],
#                   [0.0129, 0.1492, 0.0374, 0.2873, 0.4654, 0.2602, 0.7498, 0.4395, 0.7209],
#                   [0.0431, 0.3718, 0.4447, 0.6178, 0.8779, 0.5084, 0.4198, 0.7450, 0.3994]],
#
#                  [[0.7467, 0.4867, 0.6318, 0.4681, 0.1874, 0.9327, 0.6479, 0.0885, 0.9818],
#                   [0.0432, 0.1230, 0.7874, 0.3125, 0.2123, 0.8771, 0.1967, 0.6190, 0.6263],
#                   [0.0129, 0.1492, 0.0374, 0.2873, 0.4654, 0.2602, 0.7498, 0.4395, 0.7209]]])


This result is theoretically the same as tf.gather(A, B, axis=0). However, is the memory consumption of these two the same?

There is no extra concatenation here; when you index like that, it just reads the selected values and writes them into the result. So I would guess yes, they will use the same amount of memory.
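A minimal sketch to illustrate: plain advanced indexing and torch.index_select both produce the gathered [4, 3, 9] result directly, without ever materialising an enlarged copy of A.

import torch

A = torch.rand(6, 9)
B = torch.tensor([[1, 2, 3], [0, 3, 4], [4, 5, 2], [1, 4, 5]])

out = A[B, :]  # advanced indexing: reads the selected rows into a new [4, 3, 9] tensor

# Same gather via index_select on the flattened indices; neither version
# needs an enlarged [N, N, 9] copy of A.
out2 = torch.index_select(A, 0, B.reshape(-1)).reshape(*B.shape, A.shape[1])

print(out.shape)               # torch.Size([4, 3, 9])
print(torch.equal(out, out2))  # True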

Just a follow-up question: if B is a tuple of index arrays, A[B,:] won’t work, right? e.g. B = ([1,3,4,5], [2,3], [1,3,7])

I am not a specialist in this, but it should follow the advanced indexing semantics from NumPy described here.
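Not an authoritative answer, but since those index lists have different lengths they cannot be packed into one rectangular index tensor, so indexing with each list separately is one straightforward workaround (just a sketch):

import torch

A = torch.rand(8, 9)
B = ([1, 3, 4, 5], [2, 3], [1, 3, 7])  # ragged index lists

# One gather per list, since the lists cannot form a single [k, m] index tensor.
gathered = [A[torch.tensor(idx)] for idx in B]
print([g.shape for g in gathered])
# [torch.Size([4, 9]), torch.Size([2, 9]), torch.Size([3, 9])]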