Hi Stark!
You can’t, because it doesn’t make sense to. one_hot() isn’t usefully differentiable, so a loss function that uses* it won’t be either.
one_hot() takes a torch.int64 argument and returns a torch.int64 result. PyTorch doesn’t even permit such integer-valued tensors to have requires_grad = True (because it doesn’t make sense).
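Here is a minimal sketch (the labels are just made-up example values) showing that one_hot() produces an integer tensor and that PyTorch refuses to attach requires_grad to it:

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 2, 1])           # torch.int64 class indices
onehot = F.one_hot(labels, num_classes=3)  # result is also torch.int64
print(onehot.dtype)                        # torch.int64

# integer-valued tensors cannot require grad
try:
    onehot.requires_grad_(True)
except RuntimeError as e:
    print(e)  # explains that only floating-point (or complex) tensors can require gradients
```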
*) To be more precise, a loss function could depend on the result of one_hot() and also on the results of some differentiable tensor operations. Such a loss function would be usefully differentiable in that you could backpropagate through the differentiable tensor operations. But you wouldn’t be able to backpropagate through the one_hot() part of the computation (nor anything upstream of it).
Best.
K. Frank