I have been going through the PyTorch documentation in search of an efficient way to do a per-index replacement of values inside a tensor. The problem:
dest - 4D tensor (N, H, W, C1) that I want to update
idxs - 4D int64 tensor of indices into the last (C1) dimension of dest (N, H, W, C2)
source - 4D tensor of values I want to put into dest at idxs (N, H, W, C2)
In practice, dest is a minibatch of per-pixel probabilities, where C1 is the number of distinct categories and C2 = 5, i.e. source holds the top-5 updates for each pixel. Note that C1 != C2.
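To pin down the shapes, here is a small setup sketch. It assumes (this is not stated above) that idxs and source come from torch.topk over a hypothetical per-pixel score tensor `scores`; the sizes are arbitrary.

```python
import torch

# Hypothetical sizes; C2 = 5 for the top-5 case.
N, H, W, C1, C2 = 2, 3, 4, 10, 5

# Per-pixel probabilities over C1 categories (random here, for illustration only).
dest = torch.softmax(torch.randn(N, H, W, C1), dim=-1)

# Assumption: the top-5 updates are selected from some per-pixel score tensor.
scores = torch.randn(N, H, W, C1)
source, idxs = torch.topk(scores, k=C2, dim=-1)  # values and int64 indices, both (N, H, W, C2)
```

torch.topk returns its indices as int64, which is the index dtype that index_copy_ (and scatter_) expect.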
I have been looking at Tensor.index_copy_, but I can't find a way to broadcast it along the N, H, W dimensions.
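For reference, Tensor.scatter_(dim, index, src) writes src[..., c] to position index[..., c] along one dimension, independently for every position in the other dimensions, so with dim=3 it covers all (n, h, w) in one call. A minimal sketch, assuming idxs is int64 and has the same shape as source; the function name is hypothetical:

```python
import torch

def col_wise_replace_scatter(dest, idxs, source):
    # For every n, h, w and every c < C2:
    #   dest[n, h, w, idxs[n, h, w, c]] = source[n, h, w, c]
    # scatter_ modifies dest in place and also returns it.
    return dest.scatter_(3, idxs, source)

dest = torch.zeros(1, 1, 2, 4)                       # N=1, H=1, W=2, C1=4
idxs = torch.tensor([[[[0, 2], [3, 1]]]])            # (1, 1, 2, 2), C2=2, int64
source = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])  # (1, 1, 2, 2)
col_wise_replace_scatter(dest, idxs, source)
# dest[0, 0, 0] is now [1., 0., 2., 0.] and dest[0, 0, 1] is [0., 4., 0., 3.]
```

The out-of-place dest.scatter(3, idxs, source) leaves dest untouched. As with index_copy_, if the same index appears twice within one pixel, which value ends up in dest is not deterministic.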
My (working) solution with a for loop:
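    def col_wise_replace(dest, idxs, source):
        '''Replace probability stacks.'''
        n, h, w = dest.shape[:3]  # n, h, w were undefined in the loop bounds
        for k in range(n):
            for i in range(h):
                for j in range(w):
                    # dest[k, i, j] is a view, so index_copy_ writes into dest
                    dest[k, i, j].index_copy_(0, idxs[k, i, j], source[k, i, j])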
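The triple loop can also be written as a single advanced-indexing assignment: index grids for k, i, j shaped (n,1,1,1), (1,h,1,1), (1,1,w,1) broadcast against idxs (n,h,w,C2), so one assignment covers every (k, i, j, c). A sketch with hypothetical function names and sizes, checked against a copy of the loop on distinct per-pixel indices:

```python
import torch

def col_wise_replace_loop(dest, idxs, source):
    # The loop from the question, with n, h, w taken from dest's shape.
    n, h, w = dest.shape[:3]
    for k in range(n):
        for i in range(h):
            for j in range(w):
                dest[k, i, j].index_copy_(0, idxs[k, i, j], source[k, i, j])
    return dest

def col_wise_replace_indexing(dest, idxs, source):
    # Broadcastable index grids; together with idxs they address dest[k, i, j, idxs[k, i, j, c]].
    n, h, w = idxs.shape[:3]
    k = torch.arange(n, device=idxs.device).view(n, 1, 1, 1)
    i = torch.arange(h, device=idxs.device).view(1, h, 1, 1)
    j = torch.arange(w, device=idxs.device).view(1, 1, w, 1)
    dest[k, i, j, idxs] = source
    return dest

torch.manual_seed(0)
N, H, W, C1, C2 = 2, 3, 4, 10, 5  # hypothetical sizes
# Distinct indices per pixel: the first C2 entries of a random permutation of range(C1).
idxs = torch.rand(N, H, W, C1).argsort(dim=-1)[..., :C2]
source = torch.rand(N, H, W, C2)

looped = col_wise_replace_loop(torch.zeros(N, H, W, C1), idxs, source)
indexed = col_wise_replace_indexing(torch.zeros(N, H, W, C1), idxs, source)
print(torch.equal(looped, indexed))
```

This mirrors the loop's dest[k, i, j] structure directly; scatter_ along dim 3 expresses the same write without building the index grids.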
The loop solution works but is slow, since it issues N*H*W separate index_copy_ calls from Python. Is there a vectorized NumPy/PyTorch-style way to achieve the same?
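On the NumPy side, the counterpart of this operation is np.put_along_axis, which writes values along one axis using an index array with the same number of dimensions as the target. A minimal sketch on NumPy arrays with made-up values:

```python
import numpy as np

dest = np.zeros((1, 1, 2, 4))                     # N=1, H=1, W=2, C1=4
idxs = np.array([[[[1, 3], [0, 2]]]])             # (1, 1, 2, 2), C2=2
source = np.array([[[[5.0, 6.0], [7.0, 8.0]]]])   # (1, 1, 2, 2)

# In place: dest[n, h, w, idxs[n, h, w, c]] = source[n, h, w, c]
np.put_along_axis(dest, idxs, source, axis=3)
# dest[0, 0, 0] is now [0., 5., 0., 6.] and dest[0, 0, 1] is [7., 0., 8., 0.]
```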