Gather each row except a specific column in a torch Tensor

This is similar to "Select specific columns of each row in a torch Tensor", but I want to gather every column except one specific column per row.

For example, given

input = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])

and

target = torch.tensor([0, 3, 2])

I want to get

output = torch.tensor([[2, 3, 4], [5, 6, 7], [9, 10, 12]])

Thanks.
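Since the title asks about gathering, here is a minimal sketch that uses `torch.gather` directly. `gather_except` is a hypothetical helper name. The idea is to build, for each row, the column indices `0..n-2` and shift every index at or after the excluded column one step to the right, so the excluded column is skipped. It assumes a 2-D input and one valid column index per row in `target`.

```python
import torch

def gather_except(x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: drop column target[i] from row i of a 2-D tensor."""
    n_rows, n_cols = x.shape
    # Column indices 0..n_cols-2 for every row, shape (n_rows, n_cols - 1).
    base = torch.arange(n_cols - 1, device=x.device).expand(n_rows, -1)
    # Indices at or past the excluded column move right by one,
    # which skips the excluded column in that row.
    idx = base + (base >= target.unsqueeze(1)).long()
    # gather picks x[i, idx[i, j]] for every (i, j).
    return torch.gather(x, 1, idx)

input = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
target = torch.tensor([0, 3, 2])
print(gather_except(input, target).tolist())  # [[2, 3, 4], [5, 6, 7], [9, 10, 12]]
```

This avoids building a full boolean mask and keeps the result's shape explicit, at the cost of a slightly less obvious index computation.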

Bro,

input = input[(1 - torch.nn.functional.one_hot(target, input.shape[1])).bool()].reshape(input.shape[0], -1)
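The one-liner above can be unpacked step by step as in the sketch below. It writes the mask inversion as `~` on a boolean tensor instead of `1 - ... .bool()`; the two forms select the same elements.

```python
import torch
import torch.nn.functional as F

input = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
target = torch.tensor([0, 3, 2])

# one_hot marks the column to drop in each row: shape (3, 4), True at target[i].
drop = F.one_hot(target, num_classes=input.shape[1]).bool()
# Invert it so the mask is True for every column we keep.
keep = ~drop
# Boolean indexing returns the kept elements as a flat 1-D tensor in row-major
# order. Each row loses exactly one element, so reshape restores (rows, cols - 1).
output = input[keep].reshape(input.shape[0], -1)
print(output.tolist())  # [[2, 3, 4], [5, 6, 7], [9, 10, 12]]
```

The `reshape` is only valid because every row drops exactly one column; if rows could drop different numbers of elements, the flat result could not be reshaped back into a rectangular tensor.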

Thanks, that’s exactly what I needed