Hi,
I have a high-dimensional tensor whose values I’m trying to update at specific indices using another tensor, like so:
import torch

N = 5
C = 2
K = 9
J = 4
x = torch.randn(N, C, K, J) # Shape: (5, 2, 9, 4)
new_values = torch.randn((N, C)) # Shape: (5, 2)
index1 = torch.arange(N)
index3 = torch.randint(K, (N,))
index4 = torch.randint(J, (N,))
x[index1, :, index3, index4] = new_values
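For reference, here is a quick check of what the working expression selects. As far as I understand PyTorch’s NumPy-style advanced indexing, the three index tensors broadcast to shape (N,), and because they are separated by the ':' slice, that broadcast dimension is placed first. The selection, and therefore new_values, has shape (N, C):

```python
import torch

N, C, K, J = 5, 2, 9, 4
x = torch.randn(N, C, K, J)
index1 = torch.arange(N)
index3 = torch.randint(K, (N,))
index4 = torch.randint(J, (N,))

# The index tensors broadcast to shape (N,). Since they are not adjacent
# (the ':' sits between them), the broadcast dimension goes first,
# followed by the sliced dimension C.
selected = x[index1, :, index3, index4]
print(selected.shape)  # torch.Size([5, 2])

# Element [i, c] of the selection is x[index1[i], c, index3[i], index4[i]].
i, c = 3, 1
assert torch.equal(selected[i, c], x[index1[i], c, index3[i], index4[i]])
```

So new_values[i, c] ends up in x[index1[i], c, index3[i], index4[i]].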
This works, but I’m looking for a similar operation that takes the indices as a tuple (the indices are generated dynamically, so I can’t hardcode the indexing of x as above) and works for an arbitrary number of dimensions, as in the example below:
mystery_update_operation(x, (index1, index3, index4), new_values)
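A minimal sketch of what mystery_update_operation could look like, assuming the dimension that gets the ':' is known (dim 1 here, as in the example). The function name comes from the question; the sliced_dim parameter is a hypothetical addition. Python indexing accepts a tuple directly, so x[(index1, slice(None), index3, index4)] is the same as x[index1, :, index3, index4], and that tuple can be built at runtime for any number of dimensions:

```python
import torch


def mystery_update_operation(x, indices, values, sliced_dim=1):
    """Hypothetical helper: write `values` into `x` in place.

    Assumes `indices` holds one index tensor for every dimension of `x`
    except `sliced_dim`, which is kept whole (the ':' in the example).
    """
    index = list(indices)
    index.insert(sliced_dim, slice(None))  # re-create the ':' at sliced_dim
    x[tuple(index)] = values  # same as x[index1, :, index3, index4] = values
    return x


# Usage with the shapes from the question
N, C, K, J = 5, 2, 9, 4
x = torch.randn(N, C, K, J)
new_values = torch.randn(N, C)
index1 = torch.arange(N)
index3 = torch.randint(K, (N,))
index4 = torch.randint(J, (N,))

expected = x.clone()
expected[index1, :, index3, index4] = new_values

mystery_update_operation(x, (index1, index3, index4), new_values)
assert torch.equal(x, expected)
```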
I have looked into scatter_ and index_put_ but can’t figure out how to make them work in this case.
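If index_put_ is preferred, one possible workaround (again only a sketch; the function name and the sliced_dim parameter are hypothetical) is to move the dimension that should stay whole to the end with movedim. index_put_ takes a tuple of index tensors and no slices, but dimensions not covered by the tuple are taken in full, and since movedim returns a view, the in-place write lands in x:

```python
import torch


def mystery_update_operation_index_put(x, indices, values, sliced_dim=1):
    """Hypothetical variant built on Tensor.index_put_.

    Moves `sliced_dim` to the last position so the index tensors cover the
    leading (now adjacent) dimensions and the trailing one is kept whole.
    """
    view = x.movedim(sliced_dim, -1)  # a view of x, e.g. (N, K, J, C)
    view.index_put_(tuple(indices), values)  # writes go through to x
    return x


# Usage with the shapes from the question
N, C, K, J = 5, 2, 9, 4
x = torch.randn(N, C, K, J)
new_values = torch.randn(N, C)
index1 = torch.arange(N)
index3 = torch.randint(K, (N,))
index4 = torch.randint(J, (N,))

expected = x.clone()
expected[index1, :, index3, index4] = new_values

mystery_update_operation_index_put(x, (index1, index3, index4), new_values)
assert torch.equal(x, expected)
```

Because the index tensors become adjacent after movedim, the selected block has shape (N, C) in the same order as new_values, so no reshaping is needed.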
Cheers.