jinfagang
(Jin Tian)
I want to index a (17, 48) tensor with a 17-element index tensor and get another 17-element tensor,
that is, get the 30th element from the first row, the 46th from the second row, …
What should I do?
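This code should work. It uses integer-array (advanced) indexing, pairing each row index from torch.arange with the matching column index in idx:
```python
import torch

x = torch.randn(17, 48)
idx = torch.randint(0, 48, (17,))  # one column index per row
out = x[torch.arange(x.size(0)), idx]  # out[i] = x[i, idx[i]], shape (17,)
```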
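An equivalent alternative, not part of the original answer, is torch.gather along dim 1. The sketch below assumes idx is an int64 tensor (the default dtype of torch.randint) and checks that both approaches agree. It also uses a small hypothetical tensor, small, to show which values are picked. Note that indices are 0-based, so "the 30th element" corresponds to column index 29.

```python
import torch

torch.manual_seed(0)
x = torch.randn(17, 48)
idx = torch.randint(0, 48, (17,))  # int64 column indices, one per row

# Advanced indexing: row i is paired with column idx[i].
fancy = x[torch.arange(x.size(0)), idx]

# gather needs an index with the same number of dims as x, so idx is
# reshaped to (17, 1) and the extra dimension is squeezed away afterwards.
gathered = x.gather(1, idx.unsqueeze(1)).squeeze(1)

print(fancy.shape)                   # torch.Size([17])
print(torch.equal(fancy, gathered))  # True

# Small worked example: rows are [0..3], [4..7], [8..11].
small = torch.arange(12).reshape(3, 4)
cols = torch.tensor([2, 0, 3])
print(small[torch.arange(3), cols].tolist())  # [2, 4, 11]
```

Advanced indexing is usually the more readable choice for this one-value-per-row pattern. gather generalizes to picking several columns per row, by passing an index of shape (17, k).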