Efficient subset selection

I’m looking for the most efficient way to perform this operation:

import torch

# Build one index range per row, then stack and flatten them into a 1D tensor
weight_range = []
for i in range(id_min.shape[0]):
    weight_range.append(torch.arange(id_min[i], id_max[i], device=device))
weight_range = torch.stack(weight_range).flatten()

Here id_min and id_max are 1D torch tensors holding the start and end of each range (every id_min:id_max range has the same length). This operation runs in the forward pass and is very time-consuming. It’s basically grabbing a small subset of weights for each batch, and the subset differs depending on the input.

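Since every range has the same length, you can build all of them at once with broadcasting instead of a Python loop, where size is that common length (id_max - id_min):

(id_min.unsqueeze(-1) + torch.arange(size,device=id_min.device)).view(-1)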
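A minimal sketch checking that the broadcast version produces the same indices as the loop, assuming every range has the same length. The helper names range_indices_loop and range_indices_broadcast, and the weights tensor used for the final gather, are hypothetical and only for illustration.

```python
import torch

def range_indices_loop(id_min, id_max):
    # Original approach: one torch.arange per row, stacked then flattened
    rows = [torch.arange(int(lo), int(hi), device=id_min.device)
            for lo, hi in zip(id_min, id_max)]
    return torch.stack(rows).flatten()

def range_indices_broadcast(id_min, size):
    # Vectorized approach: shape (B, 1) + shape (size,) broadcasts to (B, size);
    # row i holds id_min[i], id_min[i] + 1, ..., id_min[i] + size - 1
    offsets = torch.arange(size, device=id_min.device)
    return (id_min.unsqueeze(-1) + offsets).view(-1)

id_min = torch.tensor([0, 5, 2])
id_max = id_min + 3                 # assumption: all ranges share the same length
size = int(id_max[0] - id_min[0])   # that common length

idx = range_indices_broadcast(id_min, size)
print(idx)  # tensor([0, 1, 2, 5, 6, 7, 2, 3, 4])

# Hypothetical weight tensor: gather the selected entries, one row per batch element
weights = torch.arange(10.0) * 10
subset = weights[idx].view(-1, size)
print(subset)
```

The broadcast version replaces one arange call per row with a fixed number of tensor operations regardless of batch size, which is where the speedup comes from. Indexing weights with idx keeps the selection differentiable with respect to weights.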

Thanks! That’s what I’m looking for.