One-hot a multi-dim tensor?

Hello all, I have an n*m tensor of float values on which I would like to perform a row-wise one-hot operation; that is, every row should become a one-hot vector with a 1 at the position of its largest value.

What’s the most efficient way to accomplish this?

I’ve been doing:

_, indices = torch.max(my_matrix, 1)
binary_out = torch.nn.functional.one_hot(indices, num_classes=nclasses)

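The code looks alright. You could alternatively use torch.zeros(batch_size, num_classes).scatter_(1, indices.unsqueeze(1), 1), but that shouldn’t be faster than F.one_hot. Note that scatter_ needs the index tensor to have the same number of dimensions as the target, hence the unsqueeze(1) on the 1-D indices, and that this variant gives a float tensor while F.one_hot returns an int64 tensor.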
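
For anyone who wants to check that the two approaches agree, here is a minimal sketch. The 3x4 input is made up for illustration, num_classes is assumed to equal the number of columns, and argmax is used in place of torch.max since only the indices are needed.

```python
import torch
import torch.nn.functional as F

# Hypothetical 3x4 input standing in for my_matrix from the question.
my_matrix = torch.tensor([[0.10, 0.70, 0.20, 0.00],
                          [0.90, 0.05, 0.03, 0.02],
                          [0.20, 0.20, 0.10, 0.50]])
# Assumption: one class per column.
num_classes = my_matrix.size(1)

# Index of the largest value in each row, shape (3,).
indices = my_matrix.argmax(dim=1)

# Approach 1: F.one_hot, which returns an int64 tensor.
via_one_hot = F.one_hot(indices, num_classes=num_classes)

# Approach 2: scatter_ into a zero tensor. The index must be 2-D here,
# so the 1-D indices are unsqueezed to shape (3, 1).
via_scatter = torch.zeros_like(my_matrix).scatter_(1, indices.unsqueeze(1), 1.0)

# Both give the same pattern once the dtypes are aligned.
assert torch.equal(via_one_hot.to(via_scatter.dtype), via_scatter)
print(via_one_hot)
```

If you need a float result from F.one_hot (for example, to multiply with the original matrix), cast it with .float() or .to(my_matrix.dtype).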