Range-based indexing based on another tensor

Suppose I have a tensor x of shape (B, U, T), and another tensor t of shape (B, U, 2). The last dimension of t holds the start and (exclusive) end indices into the corresponding T dimension of x. I want to create a flattened tensor y that contains only the “selected” values of x, where the selection is given by the ranges in t. Assume that I know the total size of y in advance (this is easy to compute as y_size = (t[:, :, 1] - t[:, :, 0]).sum()).
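As a small concrete illustration (toy values made up for this example, with ranges treated as half-open [start, end)), here is what y and y_size should come out to:

```python
import torch

# Toy sizes chosen only for illustration
B, U, T = 1, 2, 4
x = torch.arange(B * U * T).reshape(B, U, T)  # [[[0, 1, 2, 3], [4, 5, 6, 7]]]
t = torch.tensor([[[1, 3], [0, 2]]])          # half-open ranges [start, end)

y_size = int((t[:, :, 1] - t[:, :, 0]).sum())  # 2 + 2 = 4

# Reference result: concatenate the selected slice of every (b, u) row
pieces = [x[b, u, t[b, u, 0]:t[b, u, 1]] for b in range(B) for u in range(U)]
y = torch.cat(pieces)
print(y_size, y)  # 4 tensor([1, 2, 4, 5])
```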

A naive loop-based implementation would look something like this:

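import torch

y = x.new_empty(int(y_size))  # preallocate the flattened output
offset = 0
for b in range(B):
    for u in range(U):
        start, end = t[b, u].tolist()  # plain Python ints for slicing
        y[offset : offset + end - start] = x[b, u, start:end]
        offset += end - start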

Of course, this is slow since the loop runs B*U times. I am trying to come up with a more efficient way of doing this, perhaps using index_select(), but I am stuck.

I would appreciate any help if someone can think of a faster solution.
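The question mentions index_select(). One way to use it, sketched here as a possible alternative rather than a solution from this thread, is to build the flat indices of all selected elements directly with torch.repeat_interleave, so that no (B, U, T) mask is ever materialized. select_ranges is a hypothetical helper name, and the sketch assumes t is an int64 tensor of half-open ranges with 0 <= start <= end <= T.

```python
import torch

def select_ranges(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: gather x[b, u, start:end] for all (b, u) into one flat tensor.

    Assumes t is int64 with shape (B, U, 2) and 0 <= start <= end <= T.
    """
    B, U, T = x.shape
    starts = t[..., 0].reshape(-1)                  # (B*U,)
    lengths = (t[..., 1] - t[..., 0]).reshape(-1)   # (B*U,)
    rows = torch.arange(B * U, device=x.device)

    # Repeat each row id and each start once per selected element of that row
    row_rep = torch.repeat_interleave(rows, lengths)
    start_rep = torch.repeat_interleave(starts, lengths)

    # Position of each element inside its own range: 0, 1, ..., length - 1
    seg_begin = torch.cumsum(lengths, 0) - lengths  # exclusive prefix sum
    within = (torch.arange(row_rep.numel(), device=x.device)
              - torch.repeat_interleave(seg_begin, lengths))

    # Flat index into x viewed as a 1-D tensor of size B*U*T
    flat_idx = row_rep * T + start_rep + within
    return x.reshape(-1).index_select(0, flat_idx)
```

Its work scales with y_size rather than B * U * T, which may help when the ranges are short compared to T; whether it actually beats the mask-based versions has not been measured here. On CUDA, repeat_interleave with tensor repeats must synchronize to learn the output size; since y_size is known in advance, passing it through the output_size argument of torch.repeat_interleave avoids that.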

I was able to come up with the following mask-based approach, but perhaps there is a more efficient way.

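import torch
import torch.nn.functional as F

# t must be int64 for one_hot; num_classes=T + 1 so that an exclusive end index equal to T is valid
marks = F.one_hot(t[..., 0], num_classes=T + 1) - F.one_hot(t[..., 1], num_classes=T + 1)
# The running sum is 1 inside [start, end) and 0 elsewhere; drop the extra last column
indices = torch.cumsum(marks, dim=-1)[..., :T].bool()

y = torch.masked_select(x, indices)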
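A quick way to check that the one-hot/cumsum mask reproduces the loop is to compare both on random ranges. This is a sketch: loop_reference and mask_cumsum are illustrative names, t is assumed to be int64 with 0 <= start <= end <= T, and num_classes is set to T + 1 so that an end index equal to T does not make one_hot fail.

```python
import torch
import torch.nn.functional as F

def loop_reference(x, t):
    # Straightforward per-(b, u) slicing, used as ground truth
    B, U, _ = x.shape
    return torch.cat([x[b, u, t[b, u, 0]:t[b, u, 1]] for b in range(B) for u in range(U)])

def mask_cumsum(x, t):
    T = x.shape[-1]
    # +1 at each start, -1 at each end; T + 1 classes so that end == T is allowed
    marks = F.one_hot(t[..., 0], T + 1) - F.one_hot(t[..., 1], T + 1)
    mask = marks.cumsum(-1)[..., :T].bool()
    # masked_select walks the mask in row-major order, matching the loop order
    return torch.masked_select(x, mask)

torch.manual_seed(0)
B, U, T = 3, 4, 10
x = torch.randn(B, U, T)
# Random half-open ranges with 0 <= start <= end <= T
lo = torch.randint(0, T + 1, (B, U))
hi = torch.randint(0, T + 1, (B, U))
t = torch.stack([torch.minimum(lo, hi), torch.maximum(lo, hi)], dim=-1)

print(torch.equal(mask_cumsum(x, t), loop_reference(x, t)))  # True
```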

Hi Desh!

You could also compute indices with a logical expression:

indices = torch.logical_and(torch.arange(T) >= t[:, :, 0:1],
                            torch.arange(T) < t[:, :, 1:2])

To my eye, this version is a little more readable. (Because of broadcasting,
indices has the same shape as x.)
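To make the broadcasting concrete, here is a toy example (values made up for illustration). The slices 0:1 and 1:2 keep a trailing dimension of size 1, so the (T,) position vector broadcasts against the (B, U, 1) bounds to give a (B, U, T) mask. If x lives on the GPU, torch.arange(T, device=t.device) keeps the positions on the same device.

```python
import torch

B, U, T = 1, 2, 4
x = torch.arange(B * U * T).reshape(B, U, T)
t = torch.tensor([[[1, 3], [0, 2]]])

pos = torch.arange(T)   # shape (T,)
lo = t[:, :, 0:1]       # shape (B, U, 1); 0:1 keeps the last dimension
hi = t[:, :, 1:2]       # shape (B, U, 1)
indices = torch.logical_and(pos >= lo, pos < hi)  # broadcasts to (B, U, T)

print(indices.shape)                     # torch.Size([1, 2, 4])
print(indices.int().tolist())            # [[[0, 1, 1, 0], [1, 1, 0, 0]]]
print(torch.masked_select(x, indices))   # tensor([1, 2, 4, 5])
```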

I don’t know whether this would have any efficiency benefits.


K. Frank

Thanks for your suggestion. It is indeed more readable than using one-hot and cumsum. I will try both methods and see if it makes a difference in throughput.

Update: After trying out both methods, I found that the comparison-based method suggested by @KFrank is about 10% slower overall than the one-hot/cumsum version. My guess is that this is because it needs two comparisons at each of the T positions of every (b, u) row, plus a logical_and, whereas the other version only marks two positions per row before running a single cumsum.
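For anyone who wants to repeat a comparison like this, a sketch using torch.utils.benchmark.Timer (which handles CUDA synchronization) might look as follows. The function names, tensor sizes, and run count are arbitrary choices for illustration; the printed timings depend on hardware, shapes, dtype, and device, and are not the numbers reported above.

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

def mask_cumsum(x, t):
    T = x.shape[-1]
    # one_hot/cumsum mask; T + 1 classes so that end == T is allowed
    marks = F.one_hot(t[..., 0], T + 1) - F.one_hot(t[..., 1], T + 1)
    return torch.masked_select(x, marks.cumsum(-1)[..., :T].bool())

def mask_compare(x, t):
    # Comparison-based mask built by broadcasting
    pos = torch.arange(x.shape[-1], device=x.device)
    return torch.masked_select(x, (pos >= t[..., 0:1]) & (pos < t[..., 1:2]))

# Arbitrary sizes for illustration only
B, U, T = 8, 64, 256
x = torch.randn(B, U, T)
lo = torch.randint(0, T + 1, (B, U))
hi = torch.randint(0, T + 1, (B, U))
t = torch.stack([torch.minimum(lo, hi), torch.maximum(lo, hi)], dim=-1)

results = {}
for name, fn in [("one_hot + cumsum", mask_cumsum), ("comparisons", mask_compare)]:
    timer = benchmark.Timer(stmt="fn(x, t)", globals={"fn": fn, "x": x, "t": t})
    results[name] = timer.timeit(100)  # Measurement; .mean is seconds per run
    print(name, f"{results[name].mean * 1e6:.1f} us")
```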

I also found this related post, where @ptrblck suggests a similar approach.