Advanced/Fancy indexing across batches


Suppose I have a batch b of shape BxCxHxW, and two tensors y and x of shape BxHxW. My goal is to return a tensor b_sampled consisting of B tensors, each sampled from the corresponding element of b according to the entries of x and y.

In other words, I want b_sampled[0] = b[0,[x[0,:,:], y[0,:,:]]], b_sampled[1] = b[1,[x[1,:,:], y[1,:,:]]],…

I can do this in a loop but there must be a better way.
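For reference, here is a minimal sketch of the loop and a loop-free equivalent using PyTorch advanced indexing. It assumes x holds row indices into H and y holds column indices into W (both int64), so that b_sampled[i] = b[i, :, x[i], y[i]] has shape CxHxW. The function names `sampled_loop` and `sampled_vectorized` are illustrative only.

```python
import torch

def sampled_loop(b, x, y):
    # Reference version: one advanced-indexing call per batch element.
    # b[i][:, x[i], y[i]] has shape (C, H, W).
    return torch.stack([b[i][:, x[i], y[i]] for i in range(b.shape[0])])

def sampled_vectorized(b, x, y):
    B, C = b.shape[:2]
    # Index tensors that broadcast together to shape (B, C, H, W).
    batch_idx = torch.arange(B, device=b.device).view(B, 1, 1, 1)
    chan_idx = torch.arange(C, device=b.device).view(1, C, 1, 1)
    # out[i, c, h, w] = b[i, c, x[i, h, w], y[i, h, w]]
    return b[batch_idx, chan_idx, x.unsqueeze(1), y.unsqueeze(1)]

B, C, H, W = 2, 3, 4, 5
b = torch.randn(B, C, H, W)
x = torch.randint(0, H, (B, H, W))
y = torch.randint(0, W, (B, H, W))
print(torch.equal(sampled_loop(b, x, y), sampled_vectorized(b, x, y)))  # True
```

Because all four indices are tensors, PyTorch broadcasts them to a common (B, C, H, W) shape and the result takes that shape directly, with no permute needed.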


You can do this using:

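import torch

def batched_index_select(input, dim, index):
    # index has shape (B, M): one list of M indices per batch element.
    # Insert singleton dims so index has as many dims as input.
    for ii in range(1, len(input.shape)):
        if ii != dim:
            index = index.unsqueeze(ii)
    # Broadcast index to input's shape, except along the batch dim (0)
    # and the selected dim, which keep index's own sizes.
    expanse = list(input.shape)
    expanse[0] = -1
    expanse[dim] = -1
    index = index.expand(expanse)
    # out[b, ..., m, ...] = input[b, ..., index[b, m], ...]
    return torch.gather(input, dim, index)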

Taken from here.
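Note that batched_index_select selects along a single dimension with an index of shape (B, M), so it does not take x and y directly. One way to apply the same torch.gather trick to the original question, sketched below under the assumption that x indexes H and y indexes W, is to flatten the spatial dims and combine the two indices into a linear index x * W + y. The name `sample_with_gather` is hypothetical.

```python
import torch

def sample_with_gather(b, x, y):
    B, C, H, W = b.shape
    # Combine row/column indices into one linear index into the flattened H*W axis.
    flat_idx = (x * W + y).reshape(B, 1, H * W)
    # Repeat the same spatial index for every channel (expand returns a view).
    flat_idx = flat_idx.expand(B, C, H * W)
    # out[i, c, k] = b_flat[i, c, flat_idx[i, c, k]]
    out = torch.gather(b.reshape(B, C, H * W), 2, flat_idx)
    return out.view(B, C, H, W)

B, C, H, W = 2, 3, 4, 5
b = torch.randn(B, C, H, W)
x = torch.randint(0, H, (B, H, W))
y = torch.randint(0, W, (B, H, W))
out = sample_with_gather(b, x, y)
print(out.shape)  # torch.Size([2, 3, 4, 5])
```

Since flat index k = h * W + w, reshaping the gathered result back to (B, C, H, W) gives out[i, c, h, w] = b[i, c, x[i, h, w], y[i, h, w]].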