I am trying to copy a small tensor into another tensor at a list of indices.
I have three tensors:
a # Shape (batch, k, feat)
b # Shape (batch, feat)
c # Shape (batch,1)
I want to plug b into a at the positions given by c. Right now I am using a for loop:
for batch in range(a.shape[0]):
    a[batch, c[batch]] = b[batch]
I have tried to speed this up using index_copy_, index_select, and gather, but I can't figure out how to do this with vectorized torch functions. I believe this is similar to this unanswered question: Index_copy_ on several dimensions.
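For reference, one vectorized formulation (my own sketch, not something asserted in the question) uses advanced indexing: pair a `torch.arange` over the batch dimension with the per-batch row index from c, and assign b in a single statement. The shapes and example values below are illustrative assumptions.

```python
import torch

# Illustrative shapes: a is (batch, k, feat), b is (batch, feat), c is (batch, 1)
batch, k, feat = 4, 3, 5
a = torch.zeros(batch, k, feat)
b = torch.arange(batch * feat, dtype=torch.float).reshape(batch, feat)
c = torch.tensor([[0], [2], [1], [0]])  # target row in dim 1 for each batch element

# Advanced indexing: for each i, write b[i] into a[i, c[i, 0]].
a[torch.arange(batch), c.squeeze(1)] = b
```

This does the same work as the loop without Python-level iteration. An equivalent in-place alternative, if a scatter formulation is preferred, is `a.scatter_(1, c.unsqueeze(-1).expand(-1, -1, feat), b.unsqueeze(1))`.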