How to learn the index of a tensor?

Hi, how can I use PyTorch to learn an index?
Consider the following code:

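```python
import torch
import torch.nn as nn

input = nn.Parameter(torch.randn(3, 3))
# The index starts as a float so that it is allowed to require gradients.
index_param = nn.Parameter(torch.tensor([1.0, 2.0]))

# Indexing needs an integer tensor; this cast cuts the autograd graph.
index = index_param.long()
output = input[index, :]
output_sum = output.sum()
output_sum.backward()
print("input_grad", input.grad)
print("index_grad", index_param.grad)
```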

I got this output:

```
input_grad tensor([[0., 0., 0.],
        [1., 1., 1.],
        [1., 1., 1.]])
index_grad None
```

As expected, index has no gradient because it is cast to an integer type.
How can I make the index learnable? Thanks!

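You can't backpropagate a gradient to index, because output_sum is not differentiable with respect to index. Indexing requires an integer tensor, and integer tensors cannot carry gradients in PyTorch. Even if you treated the index as continuous, the lookup would be piecewise constant in it, so its gradient would be zero almost everywhere.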
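The claim above can be checked directly with a small sketch (the variable names are illustrative, not from the thread). It shows that a long tensor cannot require gradients, that casting a float tensor to long drops the result from the autograd graph, and that a float index passed through `torch.round` only ever receives a zero gradient.

```python
import torch

# 1) Integer tensors cannot require gradients; only floating-point (and
#    complex) tensors can, so requires_grad_ raises a RuntimeError here.
idx = torch.tensor([1, 2])
try:
    idx.requires_grad_(True)
    int_can_require_grad = True
except RuntimeError:
    int_can_require_grad = False

# 2) Casting a float tensor that requires grad to long detaches the result
#    from the autograd graph.
float_idx = torch.tensor([1.0, 2.0], requires_grad=True)
long_idx = float_idx.long()

# 3) Keeping the index as a float and rounding it does not help either:
#    rounding is piecewise constant, and autograd defines its gradient as zero.
x = torch.tensor([1.3, 1.7], requires_grad=True)
torch.round(x).sum().backward()

print(int_can_require_grad)    # False
print(long_idx.requires_grad)  # False
print(x.grad)                  # tensor([0., 0.])
```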
Can you explain why you want index to be learnable?

Maybe my network is unusual, but a learnable index is necessary for it to work. Do you know an alternative way to learn the index?
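One commonly used workaround, sketched here under assumptions and not taken from this thread, is to relax the hard index into a soft one: keep learnable logits over the candidate rows, turn them into weights with a softmax, and take a weighted sum of the rows. The names `index_logits`, `soft_select` and `temperature` are hypothetical. The result only approximates the hard lookup, but the logits do receive gradients, and their argmax gives a hard index whenever one is needed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

num_rows, num_cols, num_picks = 3, 3, 2
data = nn.Parameter(torch.randn(num_rows, num_cols))

# Hypothetical replacement for the hard index [1, 2]: one row of learnable
# logits per selected position, scoring every candidate row of `data`.
index_logits = nn.Parameter(torch.zeros(num_picks, num_rows))


def soft_select(data, logits, temperature=1.0):
    """Differentiable stand-in for data[index, :] (a relaxation, not exact)."""
    # Softmax turns each row of logits into convex weights over the rows of
    # `data`, so the output is a weighted average of rows, not a hard lookup.
    weights = F.softmax(logits / temperature, dim=-1)  # (num_picks, num_rows)
    return weights @ data                               # (num_picks, num_cols)


output = soft_select(data, index_logits)
output.sum().backward()
print(output.shape)             # torch.Size([2, 3])
print(index_logits.grad.shape)  # torch.Size([2, 3])

# The current hard choice can be read off the logits at any time.
learned_index = index_logits.argmax(dim=-1)

# Sanity check: with very peaked logits the soft selection matches the hard
# lookup data[[1, 2], :] from the question.
hard_index = torch.tensor([1, 2])
peaked = 50.0 * F.one_hot(hard_index, num_classes=num_rows).float()
with torch.no_grad():
    approx = soft_select(data, peaked)
    exact = data[hard_index, :]
print(torch.allclose(approx, exact, atol=1e-4))  # True
```

Lowering the temperature during training makes the weights sharper. If a hard one-hot choice is required in the forward pass, `torch.nn.functional.gumbel_softmax` with `hard=True` is a related option that uses a straight-through gradient estimate; whether either relaxation suits this particular network is a judgment call.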