Get the last token's tensor from a padded tensor according to seq_len

Hi,

Sorry for my poor English; I could not find an answer to my question by searching.

I have a padded tensor tensor_a with shape [B, S, H] and a tensor seq_len with shape [B]:

tensor_a =
tensor([[[0.8906, 0.5816, 0.3081],
         [0.7228, 0.2029, 0.9706],
         [0.8582, 0.5991, 0.2073],
         [0.5673, 0.2016, 0.5965]],

        [[0.7556, 0.4601, 0.5045],
         [0.5717, 0.2664, 0.0891],
         [0.2461, 0.0900, 0.6039],
         [0.6516, 0.0964, 0.6036]]])

seq_len = tensor([3, 1])

I want to get a new tensor:

tensor([[0.8582, 0.5991, 0.2073],
        [0.5717, 0.2664, 0.0891]])

I just want the tensor of the last word in each sequence. How can I do that?

This should work:

import torch

tensor_a = torch.tensor([[[0.8906, 0.5816, 0.3081],
                          [0.7228, 0.2029, 0.9706],
                          [0.8582, 0.5991, 0.2073],
                          [0.5673, 0.2016, 0.5965]],
                         
                         [[0.7556, 0.4601, 0.5045],
                          [0.5717, 0.2664, 0.0891],
                          [0.2461, 0.0900, 0.6039],
                          [0.6516, 0.0964, 0.6036]]])

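# Here seq_len holds the zero-based index of the last token in each sequence.
# If you store true lengths instead, index with seq_len - 1.
seq_len = torch.tensor([2, 1])
# Pair each batch index with its time-step index (advanced indexing): result is [B, H]
res = tensor_a[torch.arange(tensor_a.size(0)), seq_len]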
print(res)
# tensor([[0.8582, 0.5991, 0.2073],
#         [0.5717, 0.2664, 0.0891]])
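
If seq_len stores the true lengths rather than indices (as the [3, 1] in the question suggests), the same indexing trick can be wrapped in a small helper that subtracts 1 first. This is only a sketch: last_token is a hypothetical helper name, not a PyTorch function, and it assumes every length is at least 1. Note that with lengths [3, 1] the second sequence has a single valid step, so its last token is row 0, not row 1.

```python
import torch

def last_token(padded: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pick the last valid step of each sequence.

    padded:  [B, S, H] padded batch
    lengths: [B] true sequence lengths (assumed >= 1)
    returns: [B, H]
    """
    # Last valid position is length - 1 (zero-based indexing)
    idx = lengths.to(device=padded.device, dtype=torch.long) - 1
    # One batch index per sequence, paired element-wise with idx
    batch = torch.arange(padded.size(0), device=padded.device)
    return padded[batch, idx]

# Toy input: values 0..23 so each row is easy to recognise
tensor_a = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)
lengths = torch.tensor([3, 1])
res = last_token(tensor_a, lengths)
print(res)
```

Here the first sequence contributes its third row and the second sequence contributes its first row.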