Hi, I have a question about getting specific indexed elements from a tensor.
For example, I have this tensor:
import torch
a = torch.arange(24).to(torch.float).reshape(2, 3, 4)
a.requires_grad = True
print(a)
# output
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]], requires_grad=True)
I also have a list of index tuples, and I want to extract the corresponding elements
from tensor a. I tried the following:
index_list = [(0, 0, 0), (0, 1, 1)]
b = torch.Tensor([a[i[0], i[1], i[2]] for i in index_list])
print(b)
# output
tensor([0., 5.])
However, when I check b.requires_grad, it is False.
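A likely explanation: torch.Tensor(...) builds a brand-new tensor from the values in the list, so the result has no link to a's autograd graph. Each element a[i0, i1, i2] still requires grad on its own, and combining them with torch.stack (instead of the torch.Tensor constructor) should keep that link. A minimal sketch, reusing the same a and index_list:

```python
import torch

# Same setup as in the question: a 2x3x4 float tensor that requires grad.
a = torch.arange(24).to(torch.float).reshape(2, 3, 4)
a.requires_grad = True
index_list = [(0, 0, 0), (0, 1, 1)]

# Indexing with a tuple of ints returns a 0-d tensor that is still part of the graph.
elems = [a[i] for i in index_list]
print(elems[0].requires_grad)  # True

# torch.stack joins the 0-d tensors along a new dimension and records the op for autograd,
# unlike torch.Tensor(...), which only copies the values.
b = torch.stack(elems)
print(b.requires_grad)  # True
```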
How can I achieve the same result without losing the requires_grad property?
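One possible approach (a sketch, not necessarily the only way) avoids the Python loop entirely and uses PyTorch's advanced indexing: the list of (i, j, k) tuples is turned into an integer tensor and split into one index tensor per dimension. Indexing a with those tensors selects all the points at once, and the result stays in the autograd graph. The backward call at the end is only there to check that gradient flows back to exactly the selected positions.

```python
import torch

a = torch.arange(24).to(torch.float).reshape(2, 3, 4)
a.requires_grad = True
index_list = [(0, 0, 0), (0, 1, 1)]

# Shape (num_points, 3): one row per (i, j, k) tuple, dtype int64.
idx = torch.tensor(index_list)

# Advanced indexing with one index tensor per dimension -> shape (num_points,).
# Equivalent form: b = a[tuple(idx.t())]
b = a[idx[:, 0], idx[:, 1], idx[:, 2]]
print(b.requires_grad)  # True

# Backprop through the selection: only the selected entries of a get gradient 1.
b.sum().backward()
print(a.grad[0, 0, 0].item(), a.grad[0, 1, 1].item(), a.grad.sum().item())  # 1.0 1.0 2.0
```

Compared with stacking elements one by one, this does a single indexing op, which should scale better when index_list is long.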
Thanks in advance!