Given a list of indices, obtain the corresponding elements from a tensor

Hi, I have a question about getting the elements at certain indices from a tensor.

For example, I have the following tensor:

import torch

a = torch.arange(24).to(torch.float).reshape(2, 3, 4)
a.requires_grad = True
print(a)
# output
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]], requires_grad=True)

I also have a list of index tuples, and I want to extract the corresponding elements from tensor a. I tried the following:

index_list = [(0, 0, 0), (0, 1, 1)]
b = torch.Tensor([a[i[0], i[1], i[2]] for i in index_list])
# output
tensor([0., 5.])

However, b.requires_grad is False.
How can I get the same result without losing the requires_grad property, i.e. while keeping b attached to the computation graph?

Thanks in advance!

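Rewrapping tensors in torch.Tensor creates a new tensor that is detached from the computation graph, so you might want to use torch.stack or torch.cat instead. Note that each a[i[0], i[1], i[2]] is a zero-dimensional tensor: torch.stack accepts these directly, while torch.cat needs inputs with at least one dimension (e.g. a[i[0], i[1], i[2]].unsqueeze(0)).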
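A minimal sketch of the suggestion above, reusing the tensor and index list from the question. The stack and cat variants follow the answer directly. The advanced-indexing variant (one index tensor per dimension) is a vectorized alternative that was not mentioned in the thread. The names idx, b_stack, b_cat and b_adv are just for illustration.

```python
import torch

# Same tensor as in the question
a = torch.arange(24).to(torch.float).reshape(2, 3, 4)
a.requires_grad = True

index_list = [(0, 0, 0), (0, 1, 1)]

# Indexing with a full tuple returns a 0-dim tensor that stays in the graph.
# torch.stack adds a new leading dimension, so 0-dim inputs are fine.
b_stack = torch.stack([a[i] for i in index_list])

# torch.cat cannot concatenate 0-dim tensors, so give each one a length-1 dim.
b_cat = torch.cat([a[i].unsqueeze(0) for i in index_list])

# Alternative (not from the thread): advanced indexing, one index tensor per dim.
idx = torch.tensor(index_list)  # shape (num_indices, 3)
b_adv = a[idx[:, 0], idx[:, 1], idx[:, 2]]

print(b_stack, b_stack.requires_grad)
print(b_cat, b_cat.requires_grad)
print(b_adv, b_adv.requires_grad)

# Gradients flow back only to the selected positions of a.
b_stack.sum().backward()
print(a.grad)
```

All three results contain 0. and 5. and have requires_grad=True. After the backward call, a.grad is 1 at [0, 0, 0] and [0, 1, 1] and 0 everywhere else. The advanced-indexing form avoids the Python loop, which matters when the index list is long.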


Yes, I tried cat and it works. Thanks a lot!