Sample tensor at specific coordinates using Spatial Transformer Networks


I have two tensors:

  • features - with shape [batch_size, channels, height, width]
  • coordinates - with shape [batch_size, num_coordinates, 2]

The coordinates tensor has values in the range [-1, 1], and every row indicates a location in the features tensor (height and width are also considered to be normalized between -1 and 1).

I would like to sample the features tensor at the coordinates described by the coordinates tensor using STNs. The result must have the shape: [batch_size, channels, num_coordinates].

Does anyone have a simple code solution?

Thank you!

Solved with:

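import torch
# Add a dummy grid-height dimension: [batch_size, 1, num_coordinates, 2]; grid_sample reads the last dimension as (x, y).
coordinates = coordinates.unsqueeze(1)
features = torch.nn.functional.grid_sample(features, coordinates, align_corners=True)
# Drop the dummy dimension: [batch_size, channels, 1, num_coordinates] -> [batch_size, channels, num_coordinates]
features = features.squeeze(2)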
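
For anyone who wants to check this, here is a small self-contained sketch of the same approach. The helper name sample_at_coordinates and the toy tensor sizes are made up for illustration and are not from the original post. It assumes each coordinate row is ordered (x, y), i.e. width first and then height, which is how grid_sample interprets the last dimension of the grid. With align_corners=True, -1 and 1 map exactly to the first and last pixel centers, so sampling the four corners should return the corner values of the feature map.

```python
import torch
import torch.nn.functional as F


def sample_at_coordinates(features, coordinates):
    """Sample features [B, C, H, W] at coordinates [B, N, 2] given as (x, y) in [-1, 1].

    Hypothetical helper wrapping the three lines from the solution above.
    """
    grid = coordinates.unsqueeze(1)  # [B, 1, N, 2]
    sampled = F.grid_sample(features, grid, mode="bilinear", align_corners=True)  # [B, C, 1, N]
    return sampled.squeeze(2)  # [B, C, N]


# Toy sizes, chosen only for this example.
batch_size, channels, height, width = 2, 3, 4, 5
features = torch.arange(
    batch_size * channels * height * width, dtype=torch.float32
).reshape(batch_size, channels, height, width)

# Four corners as (x, y): top-left, top-right, bottom-left, bottom-right.
corners = torch.tensor([[-1.0, -1.0], [1.0, -1.0], [-1.0, 1.0], [1.0, 1.0]])
coordinates = corners.repeat(batch_size, 1, 1)  # [B, 4, 2]

sampled = sample_at_coordinates(features, coordinates)

# The same values obtained by direct indexing, stacked along the last dimension.
expected = torch.stack(
    [
        features[:, :, 0, 0],
        features[:, :, 0, -1],
        features[:, :, -1, 0],
        features[:, :, -1, -1],
    ],
    dim=-1,
)  # [B, C, 4]

assert sampled.shape == (batch_size, channels, 4)
assert torch.allclose(sampled, expected)
```

If your coordinates are stored as (y, x) instead, flipping the last dimension with coordinates.flip(-1) before calling the helper should give the expected ordering.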