Instead of the for loop to create your selection_tensor, you could also use indexing:
selection_tensor[torch.arange(selection.size(0)), indices.t()] = 1
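To see why the one-liner matches the loop, here is a minimal sketch. It assumes, hypothetically, that `selection` is an (N, D) tensor, that `indices` has shape (N, k) (here produced by `topk`), and that `selection_tensor` starts as zeros with the same shape as `selection`. The names `loop_tensor` and `k` are illustrative only. `torch.arange(N)` has shape (N,) and `indices.t()` has shape (k, N), so the two index tensors broadcast to (k, N), and each entry (j, i) writes to row i, column `indices[i, j]`.

```python
import torch

torch.manual_seed(0)

# Hypothetical setup: `selection` holds N rows of scores, and `indices`
# (shape N x k) holds the k column positions to mark in each row.
selection = torch.randn(4, 6)
k = 2
indices = selection.topk(k, dim=1).indices  # shape (N, k)

# Loop version: mark each row's selected columns one row at a time.
loop_tensor = torch.zeros_like(selection)
for i in range(selection.size(0)):
    loop_tensor[i, indices[i]] = 1

# Indexing version: arange(N) has shape (N,), indices.t() has shape (k, N).
# They broadcast to (k, N), so entry (j, i) sets row i, column indices[i, j].
selection_tensor = torch.zeros_like(selection)
selection_tensor[torch.arange(selection.size(0)), indices.t()] = 1

print(torch.equal(loop_tensor, selection_tensor))  # True
```

The transpose matters: without it, an (N,) row index would be broadcast against an (N, k) column index, which only works when N happens to equal k.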