How to efficiently collect the coordinates within several bounding boxes with PyTorch

I have a bounding boxes variable of size [num_boxes, 4], and I want to get all the coordinates within the area of every box. Is there an efficient way to do this with PyTorch?
Thank you!

Are the bounding boxes all the same size?

No, they have different sizes and can appear at random positions.

Since the bounding boxes are not all the same size, I can't think of any way to vectorize this. :(

Thank you anyway!
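Differing box sizes do not necessarily rule out vectorization. A minimal sketch, assuming integer pixel coordinates, inclusive edges, and boxes stored as (left, top, right, bottom): count the points in each box, use `torch.repeat_interleave` to tag every output position with the box it belongs to, then unravel each position row by row inside its box. The function name `box_coordinates` is hypothetical.

```python
import torch


def box_coordinates(bboxes: torch.Tensor):
    """Enumerate every integer (x, y) inside each box, edges inclusive.

    Assumes bboxes is [num_boxes, 4] laid out as (left, top, right, bottom).
    Returns (coords [N, 2] as (x, y), box_idx [N]), where box_idx[i] is the
    index of the box that coords[i] came from.
    """
    bboxes = bboxes.long()
    left, top, right, bottom = bboxes.unbind(dim=1)
    # Clamp so that degenerate boxes (right < left or bottom < top) yield no points.
    widths = (right - left + 1).clamp(min=0)
    heights = (bottom - top + 1).clamp(min=0)
    counts = widths * heights  # number of integer points per box

    # Box index for every output coordinate, e.g. counts [2, 3] -> [0, 0, 1, 1, 1].
    box_idx = torch.repeat_interleave(torch.arange(len(bboxes)), counts)

    # Position of each coordinate within its own box (0 .. count - 1).
    starts = torch.cumsum(counts, dim=0) - counts
    local = torch.arange(int(counts.sum())) - starts[box_idx]

    # Unravel the local position row by row inside each box.
    w = widths[box_idx]
    x = left[box_idx] + local % w
    y = top[box_idx] + torch.div(local, w, rounding_mode="floor")
    return torch.stack([x, y], dim=1), box_idx


# Usage: a 2x2 box and a 3x1 box.
coords, box_idx = box_coordinates(torch.tensor([[0, 0, 1, 1], [5, 5, 7, 5]]))
print(coords.tolist())
print(box_idx.tolist())
```

Memory grows with the total number of pixels inside the boxes rather than with the image size, and no Python loop runs over the boxes.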

I think this might be solvable with numpy.mgrid, but I am not sure whether PyTorch has an equivalent. Either way, NumPy is often fast enough for this kind of array manipulation in Python.
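For reference, `torch.meshgrid` with `indexing="ij"` (available since PyTorch 1.10) behaves like `numpy.mgrid` for this purpose. A minimal sketch, assuming integer coordinates, inclusive edges, and boxes stored as (left, top, right, bottom); the helper name `box_coordinates_meshgrid` is hypothetical. Because the boxes differ in size, it loops over them in Python and vectorizes only within each box:

```python
import torch


def box_coordinates_meshgrid(bboxes: torch.Tensor):
    """Return a list with one [h * w, 2] tensor of (x, y) coordinates per box.

    Assumes bboxes is an integer tensor [num_boxes, 4] laid out as
    (left, top, right, bottom), edges inclusive.
    """
    coords = []
    for left, top, right, bottom in bboxes.tolist():
        # Like np.mgrid[top:bottom + 1, left:right + 1]; arange's end is exclusive.
        ys, xs = torch.meshgrid(
            torch.arange(top, bottom + 1),
            torch.arange(left, right + 1),
            indexing="ij",
        )
        coords.append(torch.stack([xs.flatten(), ys.flatten()], dim=1))
    return coords


parts = box_coordinates_meshgrid(torch.tensor([[0, 0, 1, 1], [5, 5, 7, 5]]))
all_coords = torch.cat(parts)  # every coordinate of every box, shape [N, 2]
```

The Python loop is the cost of handling boxes of different sizes; it stays cheap when there are few boxes.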

Let's assume a torch.Tensor bboxes of size [# bboxes, 4], where the last dimension holds the coordinates of a single bounding box in the order left, top, right, bottom.

Let's also assume a torch.Tensor points of size [# points, 2], where the last dimension holds the coordinates of a single point in the order x, y.

All coordinates are in a system whose origin is at the top-left corner, with x increasing to the right and y increasing downward.

```python
import torch

# Bounding boxes as (left, top, right, bottom) and points as (x, y).
bboxes = torch.tensor([[0, 0, 30, 30], [50, 50, 100, 100]], dtype=torch.float32)
points = torch.tensor([[99, 50], [101, 0], [30, 30]], dtype=torch.float32)

# Pair every point with every box through broadcasting: x and y have shape
# [num_points, 1] and broadcast against the [num_boxes] box columns,
# so each condition below has shape [num_points, num_boxes].
x = points[:, None, 0]
y = points[:, None, 1]

# Conditions for a point to lie within a bounding box (edges inclusive):
# x >= left, x <= right, y >= top, y <= bottom
c1 = x <= bboxes[:, 2]
c2 = x >= bboxes[:, 0]
c3 = y <= bboxes[:, 3]
c4 = y >= bboxes[:, 1]

# Combine the conditions with logical AND rather than adding boolean masks,
# then keep every point that lies inside at least one box.
inside = c1 & c2 & c3 & c4
mask = inside.any(dim=1)

# Select all points that meet the condition. Boolean indexing also works when
# exactly one point matches, unlike nonzero().squeeze() with index_select.
# Prints the points (99, 50) and (30, 30); (101, 0) lies outside both boxes.
print(points[mask])
```
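The same containment test can answer the original question directly when the image size is known: use every pixel of the image as a candidate point. A sketch assuming a hypothetical image size (`height`, `width`) and integer boxes stored as (left, top, right, bottom) that lie inside the image:

```python
import torch

# Hypothetical image size; the boxes are assumed to lie inside it.
height, width = 8, 10

# Boxes as (left, top, right, bottom), edges inclusive.
bboxes = torch.tensor([[0, 0, 2, 1], [6, 4, 9, 7]])

# Every integer pixel coordinate (x, y) of the image, shape [height * width, 2].
ys, xs = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
points = torch.stack([xs.flatten(), ys.flatten()], dim=1)

# Containment test broadcast to [num_points, num_boxes].
x = points[:, None, 0]
y = points[:, None, 1]
inside = (
    (x >= bboxes[:, 0]) & (x <= bboxes[:, 2])
    & (y >= bboxes[:, 1]) & (y <= bboxes[:, 3])
)

# All pixel coordinates covered by at least one box.
covered = points[inside.any(dim=1)]

# (point index, box index) pairs, useful when boxes overlap.
pairs = inside.nonzero()
```

This builds a [height * width, num_boxes] mask, so it suits small images or few boxes; for large images, enumerating only the pixels inside each box scales better.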

Enjoy! :)

Thank you very much, it works!