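```python
import torch

# inputs
x = torch.tensor([[1, 2, 0],
                  [3, 0, 0],
                  [4, 5, 6]])
y = torch.tensor([2, 1, 3]) - 1  # convert 1-based column indices to 0-based
# creating one-hot mask
# borrowed from https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507/4?u=nikronic
mask = torch.zeros_like(x)  # same shape and dtype as x, so x does not have to be square
mask.scatter_(1, y.view(-1, 1), 1)  # put a 1 at column y[i] of row i
# applying mask to zero non desired values
res = x * mask
# extract the single masked value of each row; indexing with the mask
# (instead of res.nonzero()) keeps a selected value even when it is 0
res[mask.bool()]  # tensor([2, 3, 6])
```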
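To see why indexing with the mask is safer than `res.nonzero()`, here is a small sketch where the chosen element of the first row happens to be 0. The helper name `one_hot_mask` is hypothetical, used only for this illustration. The `nonzero()` route silently drops that row, while the mask route returns one value per row.

```python
import torch

def one_hot_mask(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: 0/1 mask shaped like x with a single 1 per row at column y[i]."""
    mask = torch.zeros_like(x)
    mask.scatter_(1, y.view(-1, 1), 1)
    return mask

x = torch.tensor([[1, 2, 0],
                  [3, 0, 0],
                  [4, 5, 6]])
y = torch.tensor([3, 1, 3]) - 1  # row 0 selects x[0, 2], which is 0

mask = one_hot_mask(x, y)
res = x * mask

by_nonzero = res[tuple(res.nonzero().T)]  # loses row 0 because its product is 0
by_mask = x[mask.bool()]                  # one value per row, in row order

print(by_nonzero)  # tensor([3, 6])
print(by_mask)     # tensor([0, 3, 6])
```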

I am not sure whether there is a better solution; this is just what came to mind. I hope it helps!
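For comparison, the same per-row selection can probably be done without building a mask at all. The sketch below uses `torch.gather` and plain advanced indexing (row `i` paired with column `y[i]`); both assume every `y[i]` is a valid 0-based column index, and both keep selected zeros.

```python
import torch

x = torch.tensor([[1, 2, 0],
                  [3, 0, 0],
                  [4, 5, 6]])
y = torch.tensor([2, 1, 3]) - 1  # 0-based column index for each row

# gather picks x[i, y[i]] along dim 1; the index needs the same number of dims as x
picked_gather = x.gather(1, y.view(-1, 1)).squeeze(1)

# advanced indexing pairs each row number with its column index
picked_index = x[torch.arange(x.size(0)), y]

print(picked_gather)  # tensor([2, 3, 6])
print(picked_index)   # tensor([2, 3, 6])
```

Both avoid the intermediate `len(y) x len(y)` mask, so they also work unchanged when `x` has more columns than rows.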