I have a tensor of shape (N, 3). I would like to convert each row of shape (1, 3) into a row of the same shape, with a 1 at the position of that row's maximum and a 0 in the other positions.

Is there a function that sets the maximum entry of a tensor to 1 and the rest to 0 along a given axis, and that works for tensors of any shape, not just (N, 3)?
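
To illustrate the behaviour I am after, here is a rough sketch. It assumes NumPy, which is my assumption for the example only, and the same idea should carry over to other tensor libraries. The helper name `one_hot_max` is made up for illustration. Because it relies on `argmax`, a tie goes to the first maximum along the axis.

```python
import numpy as np

def one_hot_max(x, axis=-1):
    """Hypothetical helper: 1 at the first maximum along `axis`, 0 elsewhere."""
    x = np.asarray(x)
    # Index of the maximum along `axis`, with that axis kept as size 1
    idx = np.expand_dims(np.argmax(x, axis=axis), axis=axis)
    out = np.zeros_like(x)
    # Write a 1 at each maximum position along `axis`
    np.put_along_axis(out, idx, 1, axis=axis)
    return out

a = np.array([[0.1, 0.7, 0.2],
              [0.9, 0.05, 0.05]])
result = one_hot_max(a, axis=1)  # one 1 per row
```

Passing `axis=0` instead would mark the maximum of each column, so the sketch is not tied to the (N, 3) case.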