Skip dimensions using index_select/index_copy

I am trying to copy a small tensor into another tensor at a list of indices.

I have three tensors:

a  # Shape (batch, k, feat)
b  # Shape (batch, feat)
c  # Shape (batch, 1)
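To make the shapes concrete, here is a toy setup (the sizes here are arbitrary, just for illustration):

import torch

batch, k, feat = 4, 5, 3                 # made-up example sizes
a = torch.zeros(batch, k, feat)          # destination tensor
b = torch.randn(batch, feat)             # rows to write into a
c = torch.randint(0, k, (batch, 1))      # target row index in dim 1, per batch element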

I want to write each row of b into a at the row index given by c. Right now I am using a for loop:

for batch in range(a.shape[0]):    # iterate over the batch dimension
  a[batch, c[batch]] = b[batch]    # write row b[batch] into a at row index c[batch]

I have tried to speed this up using index_copy_, index_select, and gather, but I can't figure out how to do this with vectorized torch functions. I believe this is similar to this unanswered question: Index_copy_ on several dimensions.

There is probably a better way, but this seems to leverage PyTorch's vectorization:

idxs_0 = torch.arange(a.shape[0])  # one index per batch element: 0 .. batch-1
idxs_1 = c.squeeze(1)              # squeeze only dim 1, so a batch size of 1 still works
a[idxs_0, idxs_1] = b              # each pair (idxs_0[i], idxs_1[i]) selects one row of a
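As a sanity check (using the toy setup above), the vectorized assignment matches the loop. The same update can also be written in-place with index_put_, or with scatter_, although the scatter_ index bookkeeping is heavier; a sketch:

# Loop version, for comparison
a_loop = a.clone()
for i in range(batch):
    a_loop[i, c[i]] = b[i]

# Vectorized advanced-indexing version
a_vec = a.clone()
a_vec[torch.arange(batch), c.squeeze(1)] = b
print(torch.equal(a_loop, a_vec))  # True

# Equivalent in-place form via index_put_
a_put = a.clone()
a_put.index_put_((torch.arange(batch), c.squeeze(1)), b)
print(torch.equal(a_loop, a_put))  # True

# scatter_ variant: index must be expanded to match src's shape
a_sc = a.clone()
a_sc.scatter_(1, c.view(-1, 1, 1).expand(-1, 1, feat), b.unsqueeze(1))
print(torch.equal(a_loop, a_sc))  # True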