How could I delete one element from each row of a tensor

Hi,

Suppose there is a 2D tensor:

ten = torch.randn(4, 5)

Now I need to remove one element from each row of that tensor according to another tensor:

ids = torch.tensor([2, 1, 4, 1])

This means removing ten[0, 2], ten[1, 1], ten[2, 4] and ten[3, 1] from the tensor, which results in a 4 x 4 tensor. How could I do this?

This should work:

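import torch

ten = torch.randn(4, 5)
ids = torch.tensor([2, 1, 4, 1])
# Start from all ones, then write 0 at column ids[i] of each row i
mask = torch.ones_like(ten).scatter_(1, ids.unsqueeze(1), 0.)
# Boolean indexing flattens the kept elements; reshape them back to 4 x 4
res = ten[mask.bool()].view(4, 4)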
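
To make it easier to see what the masking approach above removes, here is a small sketch with a deterministic input in place of torch.randn. The arange input is only an assumption for illustration; the mask and indexing lines are unchanged.

```python
import torch

# Deterministic 4 x 5 input so the removed entries are easy to spot.
ten = torch.arange(20.).view(4, 5)
ids = torch.tensor([2, 1, 4, 1])

# Ones everywhere, zero at (row, ids[row]).
mask = torch.ones_like(ten).scatter_(1, ids.unsqueeze(1), 0.)
res = ten[mask.bool()].view(4, 4)

print(res)
```

Row 0 should lose the value 2, row 1 the value 6, row 2 the value 14, and row 3 the value 16.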

Thanks for replying! What if the tensor has NCHW layout and I need to remove one element along dim=1?

ten = torch.randn(2, 4, 5, 5)
ids = torch.tensor([2, 1, 4, 1])
....
res = ... # should have shape of [2, 3, 5, 5]

If you want the output shape [2, 3, 5, 5], i.e. you want to drop one index along dim=1 for each sample, your ids can only have length 2, with a maximum value of 3.

I got this running, but I don't think it's the best way xD

ten = torch.randn(2, 4, 5, 5)
ids = torch.tensor([3, 1])
mask = torch.ones_like(ten).scatter_(1, ids[:, None, None, None].repeat(1, 1, 5, 5), 0.)
res = ten[mask.bool()].view(2, 3, 5, 5)
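
As a sanity check, the result can be compared against a plain slicing loop that drops channel ids[n] from sample n. This is only a sketch; the name ref is made up for the comparison.

```python
import torch

ten = torch.randn(2, 4, 5, 5)
ids = torch.tensor([3, 1])  # one channel index per sample

# Mask approach: zero out the whole channel ids[n] for each sample n.
mask = torch.ones_like(ten).scatter_(1, ids[:, None, None, None].repeat(1, 1, 5, 5), 0.)
res = ten[mask.bool()].view(2, 3, 5, 5)

# Reference: drop channel ids[n] from sample n with plain slicing.
ref = torch.stack([
    torch.cat([ten[n, :c], ten[n, c + 1:]], dim=0)
    for n, c in enumerate(ids.tolist())
])

print(torch.equal(res, ref))
```

The final view(2, 3, 5, 5) lines up here because a whole channel is removed per sample, so the kept elements stay in N, C, H, W order.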

Hi, thanks for the reply. In my case, ids is not an index tensor of shape (2,); it has shape (2, 5, 5).

ten = torch.randn(2, 4, 5, 5)
ids = torch.randint(0, 4, (2, 5, 5))
.....
res = ... # has shape of (2, 3, 5, 5)

I am afraid broadcasting does not apply here. The (2, 5, 5) ids tensor means that, for each position (n, h, w), the element ten[n, ids[n, h, w], h, w] should be removed.

How could I do this please?
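
One possible approach, as a sketch rather than a definitive answer: instead of masking, build the indices of the channels to keep at every position and use gather along dim=1. The names N, C, H, W and keep are made up for this example.

```python
import torch

N, C, H, W = 2, 4, 5, 5
ten = torch.randn(N, C, H, W)
ids = torch.randint(0, C, (N, H, W))  # channel to drop at each (n, h, w)

# For each position, list the C - 1 channels to keep: start from 0..C-2 and
# shift every index at or above the dropped channel up by one.
keep = torch.arange(C - 1).view(1, C - 1, 1, 1)
keep = keep + (keep >= ids.unsqueeze(1)).long()  # shape (N, C-1, H, W)

# Pick the kept channels per position.
res = ten.gather(1, keep)  # shape (N, C-1, H, W)

print(res.shape)
```

Applying the earlier mask and then view(2, 3, 5, 5) directly would not line up in this case: boolean indexing returns the kept elements in N, C, H, W order, and the dropped channel now differs per position. The mask approach would need the channel dimension moved last before indexing.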