Problem with batch indexing

I have a tensor of size (64, 20, 128) (it's actually the decoder outputs).
I also have an index tensor of size (64,) (it holds the sequence lengths).
I want to select the last real decoder output for each sequence, so I think the result should have shape (64, 128).
I have tried torch.gather and torch.index_select, but neither worked.
Can anybody help me?

Would this code snippet work for you?

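import torch

x = torch.randn(64, 20, 128)       # decoder outputs: (batch, time, features)
idx = torch.randint(1, 20, (64,))  # one time index per batch element
# Pair each batch index with its time index; result has shape (64, 128)
result = x[torch.arange(x.size(0)), idx]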
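
One thing to watch: the snippet above treats idx as zero-based positions along dim 1. If your tensor really holds sequence lengths, the last real output probably sits at length - 1. Here is a minimal sketch under that assumption (the names lengths, last and gather_idx are made up for illustration). It also shows how the same selection could be done with torch.gather, which needs an index with the same number of dimensions as the input; a 1-D index of shape (64,) does not satisfy that.

```python
import torch

x = torch.randn(64, 20, 128)           # (batch, time, features)
lengths = torch.randint(1, 21, (64,))  # hypothetical sequence lengths in 1..20

last = lengths - 1                     # zero-based index of the last real step
batch = torch.arange(x.size(0))
result = x[batch, last]                # shape (64, 128)

# Same selection with torch.gather: expand the index to (64, 1, 128)
gather_idx = last.view(-1, 1, 1).expand(-1, 1, x.size(2))
result_gather = x.gather(1, gather_idx).squeeze(1)  # shape (64, 128)

# Both approaches pick the same elements
assert torch.equal(result, result_gather)
```

The advanced-indexing form is shorter and usually easier to read; the gather form is mostly useful if you want to stay within a single op.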

Yes, it really works for me. Thank you very much.