Indexing a tensor using a batch of coordinates


  • batch_maps: 4D tensor of shape (B, H, W, C) of zeros
  • batch_label: 3D tensor of shape (B, N, 3)

Here batch_label is a batch of B sets of N points. Each point is (x, y, prob), where prob is the value of an array at that point.
What I want is to use batch_label to fill batch_maps with these points, producing B maps in which map_i holds the values {prob}_i at the coordinates {(x, y)}_i, for i from 1 to B.
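To pin down the convention, here is a tiny hypothetical example. The toy sizes and values are made up, and it assumes x is the column and y is the row with a channel-first layout, which is how the loop below indexes the map (map[i, :, y, x]). The target maps are built one point at a time.

```python
import torch

# Hypothetical toy sizes: B=2 maps, N=2 points each, 1 channel, 3x3 spatial grid.
B, N, C, H, W = 2, 2, 1, 3, 3

# batch_label[b, n] = (x, y, prob); x is the column, y is the row.
batch_label = torch.tensor([
    [[0., 0., 0.5], [2., 1., 0.9]],   # points for map 0
    [[1., 2., 0.3], [0., 2., 0.7]],   # points for map 1
])

# Reference result: write each prob at (row=y, col=x) of its own map only.
expected = torch.zeros(B, C, H, W)
for b in range(B):
    for x, y, prob in batch_label[b].tolist():
        expected[b, :, int(y), int(x)] = prob

print(expected[0, 0])
print(expected[1, 0])
```

Map 0 gets 0.5 at row 0, column 0 and 0.9 at row 1, column 2. Map 1 gets its own two points only.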
When I tried using x and y directly as indices, every map received the points from the whole batch. I managed to get the result I want with a for loop over the batch dimension:
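```python
import torch

for step, (im, label) in enumerate(dataloader):
    batch_maps = torch.zeros_like(im)   # same shape as the images, indexed as [b, c, y, x] below
    label = label.unsqueeze(1)          # (B, N, 3) -> (B, 1, N, 3)
    vals = label[:, :, :, 2]            # prob values, (B, 1, N)
    x = label[:, :, :, 0].long()        # x (column) indices, (B, 1, N)
    y = label[:, :, :, 1].long()        # y (row) indices, (B, 1, N)
    for i in range(im.size(0)):         # one assignment per batch element
        batch_maps[i, :, y[i, :, :], x[i, :, :]] = vals[i, :, :]
```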

Is there a way to do this without the for loop?
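One loop-free option, which is plain advanced indexing rather than anything from the thread, is to pass a batch index tensor alongside y and x. Each point can then only land in its own map. This is a minimal sketch under two assumptions: the maps are channel-first, (B, C, H, W), as the loop's map[i, :, y, x] implies (the bullet list says (B, H, W, C)), and maps and labels sit on the same device. fill_maps is a hypothetical helper name.

```python
import torch


def fill_maps(batch_maps: torch.Tensor, batch_label: torch.Tensor) -> torch.Tensor:
    """Fill every map in the batch at once with advanced indexing.

    Assumes batch_maps is channel-first, (B, C, H, W), and batch_label is
    (B, N, 3) with rows (x, y, prob). Modifies batch_maps in place.
    """
    B = batch_label.size(0)
    x = batch_label[..., 0].long()                    # (B, N) column indices
    y = batch_label[..., 1].long()                    # (B, N) row indices
    vals = batch_label[..., 2].to(batch_maps.dtype)   # (B, N) values to write
    # Batch index of every point, shape (B, 1); it broadcasts against (B, N),
    # so point n of batch element b can only land in map b.
    b_idx = torch.arange(B, device=batch_label.device).unsqueeze(1)
    # The advanced indices are separated by a slice, so the indexed region
    # batch_maps[b_idx, :, y, x] has shape (B, N, C); vals gets a trailing
    # singleton dim and is broadcast across the channels.
    batch_maps[b_idx, :, y, x] = vals.unsqueeze(-1)
    return batch_maps


# Usage with hypothetical toy data: B=2 maps, N=2 points, C=2 channels, 3x3 grid.
batch_label = torch.tensor([
    [[0., 0., 0.5], [2., 1., 0.9]],
    [[1., 2., 0.3], [0., 2., 0.7]],
])
maps = fill_maps(torch.zeros(2, 2, 3, 3), batch_label)
print(maps[1, 0])
```

For the channel-last (B, H, W, C) layout from the bullet list, the equivalent assignment would be batch_maps[b_idx, y, x] = vals.unsqueeze(-1). If one batch element contains the same (x, y) twice, which value ends up stored is not guaranteed, in this version or in the loop.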


You could try using torch.gather here.
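Note that torch.gather reads values out of a tensor at given indices rather than writing into one. On its own it does the inverse of what the question asks: given filled maps, it looks up the value at each (x, y). Below is a minimal sketch that assumes channel-first maps. It flattens H and W into one dimension (index y * W + x), so gathering along a single dim suffices. read_at_points is a hypothetical helper name.

```python
import torch


def read_at_points(batch_maps: torch.Tensor, batch_label: torch.Tensor) -> torch.Tensor:
    """Return batch_maps[b, c, y, x] for every point, shape (B, C, N).

    Assumes batch_maps is (B, C, H, W) and batch_label is (B, N, 3) with (x, y, prob).
    """
    B, C, H, W = batch_maps.shape
    N = batch_label.size(1)
    x = batch_label[..., 0].long()                  # (B, N)
    y = batch_label[..., 1].long()                  # (B, N)
    flat_idx = y * W + x                            # (B, N) position in the flattened H*W grid
    flat_maps = batch_maps.reshape(B, C, H * W)     # (B, C, H*W)
    index = flat_idx.unsqueeze(1).expand(B, C, N)   # same points for every channel
    # out[b, c, n] = flat_maps[b, c, index[b, c, n]]
    return torch.gather(flat_maps, 2, index)


# Usage with hypothetical data: values 0..17 laid out as (2, 1, 3, 3), one point per map.
maps = torch.arange(18, dtype=torch.float32).reshape(2, 1, 3, 3)
points = torch.tensor([[[2., 1., 0.]], [[0., 2., 0.]]])
read = read_at_points(maps, points)
print(read)
```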

I'm not sure how gather helps me here; I can't assign to a function call.
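The write counterpart of torch.gather is Tensor.scatter_ (in place) or Tensor.scatter (out of place). It takes the same kind of index tensor and writes src values at those positions, so nothing has to be assigned to a function call. Below is a minimal sketch that swaps in scatter for gather. It assumes channel-first maps and flattens H and W as y * W + x. fill_maps_scatter is a hypothetical helper name.

```python
import torch


def fill_maps_scatter(batch_maps: torch.Tensor, batch_label: torch.Tensor) -> torch.Tensor:
    """Return a copy of batch_maps with each prob written at its (x, y), for all maps at once.

    Assumes batch_maps is (B, C, H, W) and batch_label is (B, N, 3) with (x, y, prob).
    """
    B, C, H, W = batch_maps.shape
    N = batch_label.size(1)
    x = batch_label[..., 0].long()                    # (B, N)
    y = batch_label[..., 1].long()                    # (B, N)
    vals = batch_label[..., 2].to(batch_maps.dtype)   # src must match the maps' dtype
    index = (y * W + x).unsqueeze(1).expand(B, C, N)  # (B, C, N) flattened positions
    src = vals.unsqueeze(1).expand(B, C, N)           # same value in every channel
    flat = batch_maps.reshape(B, C, H * W)
    # Out-of-place scatter: out[b, c, index[b, c, n]] = src[b, c, n]
    return flat.scatter(2, index, src).reshape(B, C, H, W)


# Usage with hypothetical toy data: B=2 maps, N=2 points, C=2 channels, 3x3 grid.
batch_label = torch.tensor([
    [[0., 0., 0.5], [2., 1., 0.9]],
    [[1., 2., 0.3], [0., 2., 0.7]],
])
maps = fill_maps_scatter(torch.zeros(2, 2, 3, 3), batch_label)
print(maps[0, 0])
```

As with indexing, if two points in one map share a coordinate, which value scatter keeps is not guaranteed.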