I’m looking for the most efficient way to perform this operation:

```python
import torch

# One Python-level torch.arange call per row, then stack and flatten.
weight_range = []
for i in range(id_min.shape[0]):
    weight_range.append(torch.arange(id_min[i], id_max[i], device=device))
weight_range = torch.stack(weight_range).flatten()
```
Here id_min and id_max are 1D torch tensors holding the start and end of each range, and the range length id_max[i] - id_min[i] is the same for every i. This operation runs in the forward pass and is very time-consuming. It is essentially selecting a small subset of the weights for each batch, and which subset is selected depends on the input.
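Because every range has the same length, the loop can in principle be replaced by a single broadcasted addition: row i of the result is id_min[i] plus the offsets 0, 1, ..., L-1. A minimal sketch, assuming id_min and id_max are integer tensors on the same device and that the common length L can be read from the first row (the helper name `batched_ranges` is hypothetical):

```python
import torch


def batched_ranges(id_min: torch.Tensor, id_max: torch.Tensor) -> torch.Tensor:
    """Vectorized replacement for the per-row torch.arange loop.

    Assumes id_max - id_min is identical for every row and that both
    inputs are 1D integer tensors on the same device.
    """
    # Common range length L, read from the first row (assumed constant).
    length = int(id_max[0] - id_min[0])
    # Offsets [0, 1, ..., L-1], created once on the inputs' device.
    offsets = torch.arange(length, device=id_min.device, dtype=id_min.dtype)
    # Broadcast (N, 1) + (L,) -> (N, L): row i is id_min[i] + offsets.
    # flatten() gives the same layout as stacking the per-row aranges.
    return (id_min.unsqueeze(1) + offsets).flatten()


id_min = torch.tensor([2, 10, 5])
id_max = torch.tensor([5, 13, 8])
print(batched_ranges(id_min, id_max).tolist())  # [2, 3, 4, 10, 11, 12, 5, 6, 7]
```

The result can then index the weights directly, e.g. `weight[batched_ranges(id_min, id_max)]` for a hypothetical `weight` tensor. Dropping the `.flatten()` keeps an (N, L) index tensor if per-sample indexing is more convenient. Reading the length with `int(...)` forces one device-to-host sync; if the window size is already known as a Python int, passing it in instead avoids that.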