Hi, how can I use PyTorch to learn an index, i.e. make the indices used to select rows of a tensor trainable?
Consider the following code:
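import torch
import torch.nn as nn

input = nn.Parameter(torch.randn(3, 3))
index_param = nn.Parameter(torch.tensor([1.0, 2.0]))  # the index I want to learn
index = index_param.long()  # cast to integer so it can be used for indexing
output = input[index, :]  # select rows 1 and 2
output_sum = output.sum()
output_sum.backward()
print("input_grad", input.grad)
print("index_grad", index_param.grad)  # gradient of the learnable index parameter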
This prints:
input_grad tensor([[0., 0., 0.],
        [1., 1., 1.],
        [1., 1., 1.]])
index_grad None
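As expected, index gets no gradient: once it is converted to an integer (long) tensor, the operation is no longer differentiable, so autograd cannot propagate anything back to the index parameter.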
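A minimal check of that claim, assuming a recent PyTorch release: the tensor produced by `.long()` carries no autograd history, so nothing can flow back to the float parameter. The names `index_param`, `index_long`, and `x` are illustrative only.

```python
import torch
import torch.nn as nn

# Float parameter holding the row indices we would like to learn (illustrative)
index_param = nn.Parameter(torch.tensor([1.0, 2.0]))

# Casting to an integer dtype is not differentiable, so the graph is cut here
index_long = index_param.long()
print(index_long.requires_grad)  # False
print(index_long.grad_fn)        # None

# Indexing with the integer tensor still lets gradients reach x, but not index_param
x = nn.Parameter(torch.randn(3, 3))
x[index_long, :].sum().backward()
print(index_param.grad)          # None
```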
How can I make the index learnable? Thanks!
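One workaround sometimes used in this situation (a swapped-in technique, not something PyTorch provides under this name) is to replace the hard integer index with a differentiable relaxation. The sketch below uses a hypothetical helper `soft_index_rows` that keeps a continuous position and linearly interpolates between the two neighboring rows, so the gradient with respect to the position is the difference between those rows. It assumes the table has at least two rows, and the gradient only reflects the adjacent pair of rows, so treat it as a sketch under those assumptions rather than a definitive solution.

```python
import torch
import torch.nn as nn


def soft_index_rows(table: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: differentiable row lookup at continuous positions.

    Linearly interpolates between rows floor(pos) and floor(pos) + 1.
    Assumes table has at least 2 rows.
    """
    n = table.shape[0]
    pos = pos.clamp(0, n - 1)                  # keep positions inside the table
    lo = pos.floor().long().clamp(max=n - 2)   # lower neighbor (integer, no grad)
    hi = lo + 1                                # upper neighbor
    w = (pos - lo.to(pos.dtype)).unsqueeze(-1) # fractional part, differentiable in pos
    return (1 - w) * table[lo] + w * table[hi]


torch.manual_seed(0)
table = nn.Parameter(torch.randn(3, 3))
pos = nn.Parameter(torch.tensor([1.0, 2.0]))  # learnable, continuous "index"

out = soft_index_rows(table, pos)  # equals table[[1, 2]] at integer positions
out.sum().backward()
print(pos.grad)  # a real gradient now reaches the index
```

At integer positions the result matches ordinary indexing, while in between it blends two rows; an optimizer can then move `pos` continuously and you can round it afterwards if a hard index is needed.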