Best way to do multi-dimensional tensor indexing

Hello,

What is the best way to replace values at specific indices of a multi-dimensional tensor, given a smaller tensor of new values with fewer dimensions, without breaking the gradient flow?
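Here is a minimal sketch of one gradient-friendly approach. It assumes the positions to overwrite are the top-k entries along the last dim of `x[:, :, 0, :]`, and that the replacements come from a smaller tensor `new_vals` of shape `(2, 3, k)`. Both `new_vals` and `k = 2` are made up for illustration. The sketch avoids in-place writes entirely: `Tensor.scatter` builds a modified copy of the slice, and `torch.cat` reassembles the full tensor. Autograd then propagates gradients to `x` (zero at the overwritten positions) and to `new_vals`.

```python
import torch

torch.manual_seed(0)
x = torch.rand(2, 3, 4, 4, requires_grad=True)   # tensor whose gradient we care about
k = 2                                            # assumption: must not exceed x.size(-1)

# Positions to overwrite: top-k entries along the last dim of the slice x[:, :, 0, :]
top_k_ind = torch.topk(x[:, :, 0, :], k, dim=-1).indices   # shape (2, 3, k), carries no grad

# Hypothetical replacement values: one new value per selected index
new_vals = torch.full((2, 3, k), -1.0, requires_grad=True)

# Out-of-place scatter on the slice, then rebuild the full tensor along dim 2,
# so neither x nor new_vals is modified in place and autograd tracks both
row0 = x[:, :, 0, :].scatter(-1, top_k_ind, new_vals)       # shape (2, 3, 4)
y = torch.cat([row0.unsqueeze(2), x[:, :, 1:, :]], dim=2)   # shape (2, 3, 4, 4)

y.sum().backward()
# x.grad is 1 everywhere except at the overwritten positions, where it is 0;
# new_vals.grad is all ones
```

In-place index assignment can also work with autograd. However, it raises an error on a leaf tensor that requires grad, and on any tensor whose original values autograd still needs for the backward pass. The out-of-place form sidesteps both cases.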

So far I have done it like this:

import torch

input = torch.rand(2, 3, 4, 4)

# k must be <= input.size(-1), which is 4 here; a larger k makes topk raise an error
top_k_ind = torch.topk(input[:, :, 0, :], 2, largest=True)[1]  # indices only, shape (2, 3, 2)

for dim_0 in range(top_k_ind.shape[0]):
    for dim_1 in range(top_k_ind.shape[1]):
        input[dim_0, dim_1, 0, top_k_ind[dim_0, dim_1, :]] = 0  # dim 2 of input is fixed at index 0, not looped over
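The double loop can be replaced by a single `scatter_` call on the view `input[:, :, 0, :]`. Integer indexing at dim 2 returns a view, so writing into it modifies the original tensor. Below is a minimal sketch that uses `x` instead of `input` and assumes `k = 2`, and checks the result against the loop:

```python
import torch

torch.manual_seed(0)
x = torch.rand(2, 3, 4, 4)
k = 2  # assumption: must be <= x.size(-1)

# Indices of the k largest entries along the last dim of x[:, :, 0, :]
top_k_ind = torch.topk(x[:, :, 0, :], k, dim=-1, largest=True).indices  # shape (2, 3, k)

# Reference result: the nested loop from the question
looped = x.clone()
for i in range(top_k_ind.shape[0]):
    for j in range(top_k_ind.shape[1]):
        looped[i, j, 0, top_k_ind[i, j, :]] = 0

# Vectorized: write zeros along the last dim of the view, in place
vectorized = x.clone()
vectorized[:, :, 0, :].scatter_(-1, top_k_ind, 0.0)
```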

However, I doubt that a nested Python loop is the right way to do this in PyTorch.

I suspect there is a vectorized solution along these lines:

input[top_k_ind[0], top_k_ind[1], 0, top_k_ind[2]] = 0  # 0, 1, 2: dimensions
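An indexing form like this does exist in PyTorch as advanced indexing, but it needs a different shape of index. `top_k_ind[0]` selects a slice of `top_k_ind` rather than the indices for dim 0. What is missing is explicit index tensors for dims 0 and 1, for example built with `torch.arange` and shaped to broadcast against `top_k_ind`. Below is a sketch under the same assumptions as before (`x` instead of `input`, `k = 2`):

```python
import torch

torch.manual_seed(0)
x = torch.rand(2, 3, 4, 4)
k = 2  # assumption: must be <= x.size(-1)
top_k_ind = torch.topk(x[:, :, 0, :], k, dim=-1).indices  # shape (B, C, k)

B, C = top_k_ind.shape[:2]
# Index tensors for dims 0 and 1, shaped to broadcast with top_k_ind to (B, C, k)
b_idx = torch.arange(B).view(B, 1, 1)
c_idx = torch.arange(C).view(1, C, 1)

y = x.clone()
# All index arguments broadcast together; the scalar 0 fixes dim 2
y[b_idx, c_idx, 0, top_k_ind] = 0
```

Assigning a tensor of shape `(B, C, k)` instead of `0` writes one new value per selected index. That covers the "smaller tensor of new values" case, subject to the usual autograd restrictions on in-place writes.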

Thanks in advance, everyone.