Retaining grad_fn for one-hot encoded tensors

Hi Stark!

You can’t, because it doesn’t make sense to. one_hot() isn’t
usefully differentiable, so a loss function that uses* it won’t be either.

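one_hot() takes a torch.int64 argument (the class indices) and returns a
torch.int64 result. PyTorch doesn’t even permit such integer-valued tensors
to have requires_grad = True; only floating-point (and complex) tensors can
require gradients, because a gradient of an integer-valued result doesn’t
make sense.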
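Here is a minimal sketch of that behavior. The labels are made-up class indices chosen for illustration, and the exact wording of the error message may differ between PyTorch versions:

```python
import torch
import torch.nn.functional as F

# Hypothetical class indices; torch.tensor infers torch.int64 here.
labels = torch.tensor([0, 2, 1])
encoded = F.one_hot(labels, num_classes=3)  # shape (3, 3), dtype torch.int64

print(encoded.dtype)          # torch.int64
print(encoded.requires_grad)  # False
print(encoded.grad_fn)        # None

# Asking an integer tensor to track gradients is rejected outright.
raised = False
try:
    encoded.requires_grad_(True)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```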

*) To be more precise, a loss function could depend on the result
of one_hot() and also on the results of some differentiable tensor
operations. Such a loss function would be usefully differentiable
in that you could backpropagate through the differentiable tensor
operations. But you wouldn’t be able to backpropagate through the
one_hot() part of the computation (nor through anything upstream of it).
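A minimal sketch of such a mixed loss follows. It assumes, purely for illustration, that the labels come from argmax() of some model scores (the logits tensor and the squared-error loss are hypothetical). Gradients reach logits through the softmax branch, but argmax() and one_hot() cut the graph, and casting the one-hot result to float does not reconnect it:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical model scores that do require gradients.
logits = torch.randn(4, 3, requires_grad=True)

# argmax() returns int64 indices with no grad_fn: the graph is cut here.
labels = logits.argmax(dim=1)

# one_hot() output cast to float is still a constant with no grad_fn.
targets = F.one_hot(labels, num_classes=3).float()

# The loss mixes the constant targets with a differentiable function of logits.
probs = logits.softmax(dim=1)
loss = ((probs - targets) ** 2).sum()
loss.backward()

print(logits.grad is not None)  # True: gradient flows through softmax
print(targets.requires_grad)    # False: nothing flows back through one_hot()
```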

Best.

K. Frank
