No gradient for making the max value 1 and others 0

Hi,

I would like to set the maximum value in a tensor to 1 and all other values to 0. However, backward() does not produce a gradient for the input tensor.

I attached the example. Thanks!

import torch

x = torch.Tensor([[0, 1, 2, 3]]).requires_grad_()
b = (x==x.max(dim=1,keepdim=True)[0]).type(torch.FloatTensor).requires_grad_()
out = b.sum()
out.backward()

print(x.grad)  # why is there no grad for x? (prints None)
print(b.grad)

Hi,

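The problem is that the comparison x==... returns a non-floating-point (boolean) Tensor, which cannot carry gradients, and you then set requires_grad_() on it again by hand. That turns b into a new leaf of the graph that is disconnected from x, so backward() never reaches x.
You need to use only differentiable ops here.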

To keep the values between 0 and 1, I would recommend the softmax function, which ensures you get non-zero gradients everywhere (but the outputs will never reach exactly 0 or 1).
Alternatively, you can use clamp, but then the gradients for values < 0 and > 1 will be 0.
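A minimal sketch comparing the two suggestions, assuming a recent PyTorch. The inputs, softmax over dim=1, and the clamp bounds (0, 1) are illustrative choices, and the clamp inputs are picked away from the boundaries. Note that softmax(...).sum() is always 1, so backpropagating from it would give an all-zero gradient; the sketch backpropagates from the entry for the largest input instead.

```python
import torch

# Softmax: a smooth "soft" version of one-hot-at-the-max.
x = torch.tensor([[0.0, 1.0, 2.0, 3.0]], requires_grad=True)
soft = torch.softmax(x, dim=1)
# Backpropagate from the probability assigned to the largest input.
# (soft.sum() is constantly 1, so it would give an all-zero gradient.)
soft[0, 3].backward()
print(soft)    # every value is strictly between 0 and 1
print(x.grad)  # every entry is non-zero

# Clamp: exact saturation at 0 and 1, but no gradient outside [0, 1].
y = torch.tensor([[-0.5, 0.3, 0.7, 2.0]], requires_grad=True)
clamped = y.clamp(0, 1)
clamped.sum().backward()
print(clamped)  # -0.5 becomes 0, 2.0 becomes 1, the rest pass through
print(y.grad)   # tensor([[0., 1., 1., 0.]])
```

The trade-off is the one described above: softmax never produces exact 0/1 values but always has a gradient, while clamp produces exact 0/1 values only in the region where the gradient is zero.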

Thanks @albanD. Does that mean “b = (x==x.max(dim=1,keepdim=True)[0]).type(torch.FloatTensor)” is actually not differentiable, and that is why x.grad is None?

And is there a way to set the maximum value to 1 and the others to 0 in a tensor while still getting a gradient? I found a similar question, “Set Max value to 1, others to 0”, but it seems one-hot is also not differentiable.

“b = (x==x.max(dim=1,keepdim=True)[0]).type(torch.FloatTensor)” is actually not differentiable, and that is why x.grad is None?

Yes. It returns a Tensor that does not require gradients even though the input does, which means the operation is not differentiable.
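A short check of this, assuming a recent PyTorch: the comparison output is a bool tensor outside the graph, and calling requires_grad_() on the cast result only creates a new leaf. The same holds for the one-hot route mentioned in the question, because argmax returns integer indices. The variable names are illustrative.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[0.0, 1.0, 2.0, 3.0]], requires_grad=True)

# The comparison returns a bool tensor that is cut off from the graph.
mask = x == x.max(dim=1, keepdim=True).values
print(mask.dtype, mask.requires_grad)   # torch.bool False

# Casting and calling requires_grad_() makes a brand-new leaf:
# b has no grad_fn, so backward() stops at b and never reaches x.
b = mask.float().requires_grad_()
print(b.is_leaf, b.grad_fn)             # True None

b.sum().backward()
print(b.grad)  # tensor([[1., 1., 1., 1.]])
print(x.grad)  # None

# One-hot of argmax has the same problem: argmax returns integer indices.
idx = x.argmax(dim=1)
onehot = F.one_hot(idx, num_classes=x.shape[1])
print(onehot.dtype, onehot.requires_grad)  # torch.int64 False
```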

And is there a way to set the maximum value to 1 and the others to 0 in a tensor while still getting a gradient?

If you think of the gradient as “how much does the output change if I change this input a little”, you can see that if your function outputs “1 for the maximum value and 0 for the others”, then all the gradients will be 0. When you change any input a little (as long as the position of the maximum does not change), all the outputs remain exactly the same. And where the position of the maximum does change, the output jumps, so the function is not differentiable there at all.
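The argument can be checked numerically with finite differences. This is a sketch: hard_max_onehot is a hypothetical helper that reproduces the question's operation, and the step size eps = 1e-3 is an arbitrary choice. Small nudges leave the output unchanged, so every finite-difference slope is 0, while a nudge large enough to move the maximum makes the output jump.

```python
import torch

def hard_max_onehot(x):
    # Hypothetical helper: 1.0 where x equals its row maximum, 0.0 elsewhere.
    return (x == x.max(dim=1, keepdim=True).values).float()

x = torch.tensor([[0.0, 1.0, 2.0, 3.0]])
eps = 1e-3
base = hard_max_onehot(x)

# Nudge each input a little and measure how much the output moves.
changes = []
for j in range(x.shape[1]):
    x_nudged = x.clone()
    x_nudged[0, j] += eps
    diff = (hard_max_onehot(x_nudged) - base) / eps
    changes.append(diff.abs().max().item())
print(changes)  # [0.0, 0.0, 0.0, 0.0]

# A large nudge that moves the maximum makes the output jump instead.
x_big = x.clone()
x_big[0, 2] += 1.5
print(hard_max_onehot(x_big))  # tensor([[0., 0., 1., 0.]])
```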

Very clear explanation, thanks @albanD !