Sure. To keep the tensors small, I changed dim[1] of A from 32 to 9. The new code is below. There was an error in the previous code, which I have corrected.
import torch

A = torch.rand(6, 9)
# A: tensor([[0.4918, 0.9813, 0.6889, 0.3126, 0.3797, 0.5551, 0.8716, 0.2215, 0.5576],
# [0.7467, 0.4867, 0.6318, 0.4681, 0.1874, 0.9327, 0.6479, 0.0885, 0.9818],
# [0.0431, 0.3718, 0.4447, 0.6178, 0.8779, 0.5084, 0.4198, 0.7450, 0.3994],
# [0.5677, 0.7410, 0.2781, 0.9504, 0.1205, 0.6734, 0.2124, 0.4071, 0.7975],
# [0.0432, 0.1230, 0.7874, 0.3125, 0.2123, 0.8771, 0.1967, 0.6190, 0.6263],
# [0.0129, 0.1492, 0.0374, 0.2873, 0.4654, 0.2602, 0.7498, 0.4395, 0.7209]])
B = torch.tensor([[1, 2, 3], [0, 3, 4], [4, 5, 2], [1, 4, 5]])
#expected output with output = A[B,:]
output = A[B,:]
#output = tensor([[[0.7467, 0.4867, 0.6318, 0.4681, 0.1874, 0.9327, 0.6479, 0.0885, 0.9818],
#                  [0.0431, 0.3718, 0.4447, 0.6178, 0.8779, 0.5084, 0.4198, 0.7450, 0.3994],
#                  [0.5677, 0.7410, 0.2781, 0.9504, 0.1205, 0.6734, 0.2124, 0.4071, 0.7975]],
#
#                 [[0.4918, 0.9813, 0.6889, 0.3126, 0.3797, 0.5551, 0.8716, 0.2215, 0.5576],
#                  [0.5677, 0.7410, 0.2781, 0.9504, 0.1205, 0.6734, 0.2124, 0.4071, 0.7975],
#                  [0.0432, 0.1230, 0.7874, 0.3125, 0.2123, 0.8771, 0.1967, 0.6190, 0.6263]],
#
#                 [[0.0432, 0.1230, 0.7874, 0.3125, 0.2123, 0.8771, 0.1967, 0.6190, 0.6263],
#                  [0.0129, 0.1492, 0.0374, 0.2873, 0.4654, 0.2602, 0.7498, 0.4395, 0.7209],
#                  [0.0431, 0.3718, 0.4447, 0.6178, 0.8779, 0.5084, 0.4198, 0.7450, 0.3994]],
#
#                 [[0.7467, 0.4867, 0.6318, 0.4681, 0.1874, 0.9327, 0.6479, 0.0885, 0.9818],
#                  [0.0432, 0.1230, 0.7874, 0.3125, 0.2123, 0.8771, 0.1967, 0.6190, 0.6263],
#                  [0.0129, 0.1492, 0.0374, 0.2873, 0.4654, 0.2602, 0.7498, 0.4395, 0.7209]]])
This result is theoretically the same as tf.gather(A, B, axis=0) in TensorFlow. However, is the memory consumption of these two operations the same?
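I am not sure the two frameworks behave identically, but on the PyTorch side one quick way to inspect this is to check whether the indexed result shares memory with A and how many bytes it actually owns. The sketch below continues from the A, B, and output defined above and only uses standard tensor methods (nelement, element_size); it is just a rough check, not a full profiling comparison.

# output has shape (4, 3, 9): advanced indexing materializes the gathered rows.
bytes_A = A.nelement() * A.element_size()            # 6 * 9 * 4 bytes = 216 for float32
bytes_out = output.nelement() * output.element_size()  # 4 * 3 * 9 * 4 bytes = 432
print(output.shape, bytes_A, bytes_out)

# Advanced indexing returns a copy, not a view, so mutating output leaves A untouched.
output[0, 0, 0] = -1.0
print(A[1, 0])  # still the original value, confirming output owns its own storage

If tf.gather also materializes the gathered rows into a new buffer (which I believe it does, though I have not verified it here), then the memory behavior should be comparable: both allocate a new (4, 3, 9) tensor in addition to A rather than returning a view.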