Efficient batched random sampling with replacement given a mask

I have a tensor of size BxNx2, where B is the batch size, N is the number of bounding boxes, and the last dimension holds two floats, min and max. These are scalars along the ray direction: scaling the direction by each of them and translating by the ray origin gives the two intersection points with that bounding box.
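For reference, here is a minimal sketch of how the (min, max) values map to points. It assumes one ray per batch element, with hypothetical `origin` and `direction` tensors of shape (B, 3); `intersection_points` is an illustrative helper, not an existing API.

```python
import torch

def intersection_points(t, origin, direction):
    """Map the (min, max) scalars of each box to 3D points on the ray.

    Assumes one ray per batch element (hypothetical layout):
      t:         (B, N, 2) min/max scalars per box
      origin:    (B, 3)    ray origins
      direction: (B, 3)    ray directions
    Returns a tensor of shape (B, N, 2, 3).
    """
    # Broadcast origin and direction over the N boxes and the 2 scalars.
    return origin[:, None, None, :] + t[..., None] * direction[:, None, None, :]


B, N = 2, 5
t = torch.rand(B, N, 2).sort(dim=-1).values  # make sure min <= max
origin = torch.zeros(B, 3)
direction = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
pts = intersection_points(t, origin, direction)
print(pts.shape)  # torch.Size([2, 5, 2, 3])
```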

The tensor itself is sparse, so there is also a mask of size BxN that filters out boxes with no intersection.

The goal is batched sampling for efficient ray casting: for each batch element I need to sample, e.g., 64 relevant voxels/boxes where an intersection occurs, with replacement (since there can be fewer than 64), so that the resulting tensor has size Bx64x2.
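One weight-free way to get exactly Bx64x2 is to draw a uniform random position among each row's valid boxes and map it back to a box index. This is a sketch under assumptions, not a known-best method: `sample_valid_uniform` is a hypothetical helper, and it assumes `mask` is a bool tensor that is True for intersected boxes (the mask's polarity isn't stated above).

```python
import torch

def sample_valid_uniform(t, mask, k=64, generator=None):
    """Draw k boxes per row, uniformly with replacement, among boxes where mask is True.

    t:    (B, N, 2) float tensor of (min, max) values
    mask: (B, N) bool tensor, True where the ray intersects the box (assumed)
    Returns samples of shape (B, k, 2) and sampled box indices of shape (B, k).
    Rows with no valid box get an arbitrary (invalid) box; check mask.any(dim=1).
    """
    B = t.shape[0]
    # A stable descending sort of the mask moves valid box indices to the front,
    # keeping their original order.
    order = torch.sort(mask.long(), dim=1, descending=True, stable=True).indices
    count = mask.sum(dim=1, keepdim=True)  # (B, 1): number of valid boxes per row
    # Uniform position in [0, count) for each draw; the clamp guards against
    # float rounding and rows with count == 0.
    u = torch.rand(B, k, generator=generator, device=t.device)
    pos = torch.minimum((u * count).long(), (count - 1).clamp(min=0))
    idx = torch.gather(order, 1, pos)  # map positions back to box indices
    samples = torch.gather(t, 1, idx.unsqueeze(-1).expand(-1, -1, 2))
    return samples, idx


torch.manual_seed(0)
t = torch.rand(3, 10, 2)
mask = torch.zeros(3, 10, dtype=torch.bool)
mask[0, [1, 4, 7]] = True  # 3 valid boxes
mask[1, 2] = True          # a single valid box
mask[2] = True             # every box valid
samples, idx = sample_valid_uniform(t, mask, k=64)
print(samples.shape)  # torch.Size([3, 64, 2])
```

Everything here is batched (one sort along N, a sum and two gathers), so there is no Python loop over B.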

Is there an efficient way to do this? My idea is to create a ones_like tensor of size BxN, use the mask to assign weights of 1e-10 to the non-intersecting boxes, and then sample with torch.multinomial accordingly, but I suspect that's too inefficient.
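A sketch of the multinomial idea, using the same assumed mask convention (True = intersected) and a hypothetical helper name. It uses weights of exactly 0 instead of 1e-10: `torch.multinomial` accepts zero weights as long as each row has a positive sum, so non-intersecting boxes should not be drawn. Only rows with no valid box need special handling, because an all-zero row raises an error.

```python
import torch

def sample_valid_multinomial(t, mask, k=64, generator=None):
    """Draw k boxes per row with replacement via torch.multinomial.

    t: (B, N, 2) float tensor, mask: (B, N) bool tensor (True = intersected, assumed).
    Returns samples of shape (B, k, 2) and indices of shape (B, k).
    """
    weights = mask.to(t.dtype)  # 1 for valid boxes, 0 otherwise
    # An all-zero row is an invalid distribution for torch.multinomial, so give
    # rows with no valid box uniform weights (their samples are all invalid boxes).
    empty = ~mask.any(dim=1, keepdim=True)
    weights = torch.where(empty, torch.ones_like(weights), weights)
    idx = torch.multinomial(weights, k, replacement=True, generator=generator)
    samples = torch.gather(t, 1, idx.unsqueeze(-1).expand(-1, -1, 2))
    return samples, idx


torch.manual_seed(0)
t = torch.rand(3, 10, 2)
mask = torch.zeros(3, 10, dtype=torch.bool)
mask[0, [1, 4, 7]] = True
mask[1, 2] = True
mask[2] = True
samples, idx = sample_valid_multinomial(t, mask, k=64)
print(samples.shape)  # torch.Size([3, 64, 2])
```

This version is also fully batched; whether it is faster or slower than a sort-based approach depends on B, N and the device, so it is worth timing on real data.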