Seq2Seq: Get first element of every document in a batch

For an encoder-decoder task, I want to take the first element of every document in the batch and feed it into the decoder (the first element of every document is always the tag).
Let’s say I have this output tensor:

tensor([[   1, 1870, 1871,  ...,    0,    0,    0],
        [   1, 1159, 1160,  ...,    0,    0,    0],
        [   1, 1159, 1500,  ..., 1651, 1170, 1167],
        [   1, 1172, 1173,  ...,    0,    0,    0]])

How can I get:

tensor([[1],
        [1],
        [1],
        [1]])

I am following this tutorial, and the part that doesn’t work for me is line 22 under “9. Seq2Seq (Encoder + Decoder) Code Implementation”.
Does anyone have an idea? Thank you!

I’m not sure I understand the use case completely, but wouldn’t indexing the first column work?

Yes, I think so; I just don’t know how to do that.
But I think I found a working solution (target_tensor is the tensor shown above):
target_tensor[torch.arange(target_tensor.size(0)), 0]
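To show what this expression does, here is a small sketch. The batch values are a hypothetical, shortened stand-in for the tensor above (the real rows are longer), and it assumes 0 is only ever used as the padding ID. torch.arange produces one row index per document, and each row index is paired with column 0. The same pattern also lets you pick a different column per row, shown here with a hypothetical "last non-padding token" lookup.

```python
import torch

# Hypothetical, shortened stand-in for the padded batch in the question:
# 4 documents, each starting with the tag token 1 and right-padded with 0.
target_tensor = torch.tensor([
    [1, 1870, 1871, 0, 0, 0],
    [1, 1159, 1160, 0, 0, 0],
    [1, 1159, 1500, 1651, 1170, 1167],
    [1, 1172, 1173, 0, 0, 0],
])

# Advanced indexing: row i is paired with column 0 for every i in the batch.
rows = torch.arange(target_tensor.size(0))   # tensor([0, 1, 2, 3])
first = target_tensor[rows, 0]               # one element per row
print(first)         # tensor([1, 1, 1, 1])
print(first.shape)   # torch.Size([4])

# The same pattern allows a different column per row.
# Hypothetical example: last non-padding token of each row
# (assumes 0 appears only as padding).
lengths = (target_tensor != 0).sum(dim=1)    # tensor([3, 3, 6, 3])
last = target_tensor[rows, lengths - 1]
print(last)          # tensor([1871, 1160, 1167, 1173])
```

When every row uses the same column, as here, the arange form gives the same result as plain slicing on that column.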

Your approach would work, but you could also simply use target_tensor[:, 0].
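One detail worth checking: an integer index on a dimension removes that dimension, so target_tensor[:, 0] has shape [4], while the output asked for in the question has shape [4, 1]. A minimal sketch, again using a hypothetical, shortened stand-in batch, shows two ways to keep the column dimension. Whether the decoder expects [batch] or [batch, 1] depends on the tutorial’s code, which isn’t shown here.

```python
import torch

# Hypothetical, shortened stand-in for the batch from the question.
target_tensor = torch.tensor([
    [1, 1870, 1871, 0, 0, 0],
    [1, 1159, 1160, 0, 0, 0],
    [1, 1159, 1500, 1651, 1170, 1167],
    [1, 1172, 1173, 0, 0, 0],
])

# Integer index on dim 1: that dimension is dropped -> shape [4].
col = target_tensor[:, 0]
print(col.shape)     # torch.Size([4])

# Slice on dim 1: that dimension is kept -> shape [4, 1],
# matching the column layout asked for in the question.
col_2d = target_tensor[:, :1]
print(col_2d)
# tensor([[1],
#         [1],
#         [1],
#         [1]])

# Equivalent: index, then add the dimension back explicitly.
assert torch.equal(col_2d, target_tensor[:, 0].unsqueeze(1))
```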


Ah, thank you, that is way easier!