Extract multiple coordinates batchwise with PyTorch functions (similar to tf.gather_nd)

Hi everyone,

I have a problem I'm trying to solve, and so far I haven't managed to come up with a solution. In TensorFlow, one could use tf.gather_nd for it:

Let's imagine I have a batch tensor img_feat of size (16, 256, 16, 16). In addition, I have an x and a y tensor, each of size (16, 150), so that I have 150 points, each with an x and a y coordinate.

What I want to do is extract, for each batch entry (16) and each point (150), all of the features (256) at that point's x and y coordinate in the 16x16 feature maps of img_feat. More precisely, this is exactly what I want to have:

feats = torch.stack([torch.stack([img_feat[i,:,x[i,j].long(), y[i,j].long()] for j in range(x.shape[1])]) for i in range(batch_size)])
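For concreteness, here is the loop version embedded in a minimal self-contained sketch (the random data and sizes just mirror the setup above); with these shapes, feats comes out as (16, 150, 256):

```python
import torch

batch_size, c, h, w, nb_points = 16, 256, 16, 16, 150
img_feat = torch.randn(batch_size, c, h, w)
# one (x, y) coordinate per point, valid for the 16x16 feature maps
x = torch.randint(h, (batch_size, nb_points))
y = torch.randint(w, (batch_size, nb_points))

# loop version: for every batch entry i and point j, pick all c features
feats = torch.stack([
    torch.stack([img_feat[i, :, x[i, j].long(), y[i, j].long()]
                 for j in range(x.shape[1])])
    for i in range(batch_size)
])
# feats.shape == (batch_size, nb_points, c)
```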

My problem is that I want to implement it with torch functions (e.g. gather, index_select, etc.) so that it can be parallelised on GPUs.

Can anyone come up with a solution? It would be so awesome!

Thanks a lot in advance!


I think two gather calls will achieve what you want:

import torch

batch_size = 16
c, h, w = 256, 16, 16
nb_points = 150
img_feat = torch.randn(batch_size, c, h, w)
x = torch.empty(batch_size, nb_points, dtype=torch.long).random_(h)
y = torch.empty(batch_size, nb_points, dtype=torch.long).random_(w)
x = x[:, None, :, None].expand(-1, c, -1, w)
y = y[:, None, None, :].expand(-1, c, nb_points, -1)

points = torch.gather(torch.gather(img_feat, 2, x), 3, y)

Dear ptrblck,

thanks a lot for your help! I think you are nearly there; however, this is not quite what I need yet:

With your method, points has shape (batch_size, c, nb_points, nb_points).

What I want is a resulting tensor of shape (batch_size, nb_points, c). In other words, for each x and y coordinate, I want to extract all c features at that coordinate. Since I have 150 points per batch entry, the result should have the above-mentioned shape.

I attach again my simple loop-based code, which produces the desired behaviour:

feats = torch.stack([torch.stack([img_feat[i,:,x[i,j].long(), y[i,j].long()] for j in range(x.shape[1])]) for i in range(batch_size)])

Thanks so much! I am new to PyTorch, and you guys are helping me a lot!

Hey guys,

feel free to correct me, but I think this might be the correct solution:

import torch

batch_size = 16
c, h, w = 256, 16, 16
nb_points = 150
img_feat = torch.randn(batch_size, c, h, w)
x = torch.empty(batch_size, nb_points, dtype=torch.long).random_(h)
y = torch.empty(batch_size, nb_points, dtype=torch.long).random_(w)
x = x[:, None, :, None].expand(-1, c, -1, w)
y = y[:, None, :, None].expand(-1, c, -1, 1)

points = torch.gather(torch.gather(img_feat, 2, x), 3, y)

This way, the second gather should collect only one (hopefully the correct ;)) value per point out of the nb_points * w values extracted by the first gather.
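If I'm reading the shapes right, the two gathers leave a trailing dimension of size 1, so the result has shape (batch_size, c, nb_points, 1); a squeeze and a permute should then give the (batch_size, nb_points, c) layout described above. A self-contained sketch (x_idx and y_idx are just my names for the expanded index tensors):

```python
import torch

batch_size, c, h, w, nb_points = 16, 256, 16, 16, 150
img_feat = torch.randn(batch_size, c, h, w)
x = torch.randint(h, (batch_size, nb_points))
y = torch.randint(w, (batch_size, nb_points))

# expand the point coordinates over the channel dimension
x_idx = x[:, None, :, None].expand(-1, c, -1, w)   # (b, c, nb_points, w)
y_idx = y[:, None, :, None].expand(-1, c, -1, 1)   # (b, c, nb_points, 1)

# two gathers: rows first, then the single column per point
points = torch.gather(torch.gather(img_feat, 2, x_idx), 3, y_idx)
# (b, c, nb_points, 1) -> (b, nb_points, c)
feats = points.squeeze(-1).permute(0, 2, 1)
```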

That should be correct, right?

Oh I’m sorry. I completely misunderstood your use case.

This should work and is much easier:

import torch

img_feat = torch.randn(batch_size, c, h, w)
x = torch.empty(batch_size, nb_points, dtype=torch.long).random_(h)
y = torch.empty(batch_size, nb_points, dtype=torch.long).random_(w)
# advanced indexing; points has shape (batch_size, nb_points, c)
points = img_feat[torch.arange(batch_size)[:, None], :, x, y]
feats = torch.stack([torch.stack([img_feat[i,:,x[i,j].long(), y[i,j].long()] for j in range(x.shape[1])]) for i in range(batch_size)])
print((points == feats).all())
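In case a pure gather formulation is still useful (feel free to ignore this if not): flattening the two spatial dimensions into one should let a single gather with the linear index x * w + y do the same job. A sketch with the same setup:

```python
import torch

batch_size, c, h, w, nb_points = 16, 256, 16, 16, 150
img_feat = torch.randn(batch_size, c, h, w)
x = torch.randint(h, (batch_size, nb_points))
y = torch.randint(w, (batch_size, nb_points))

# flatten the spatial dims: (b, c, h, w) -> (b, c, h*w)
flat = img_feat.view(batch_size, c, h * w)
# linear index per point, expanded over the channel dim: (b, c, nb_points)
idx = (x * w + y)[:, None, :].expand(-1, c, -1)
# gather along the flattened spatial dim, then move channels last
points = torch.gather(flat, 2, idx).permute(0, 2, 1)  # (b, nb_points, c)
```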

awesome, thanks so much ptrblck!
