Replace elements of a tensor with a variable number of zeros along a certain dimension

Suppose there is a tensor A = tensor([[0.4869, 0.5144, 0.9086, 0.6139],
[0.5103, 0.8270, 0.4832, 0.8980],
[0.5234, 0.1135, 0.1037, 0.7451]])

I want to replace the first t[i] elements of each row i with zeros, where the per-row counts come from another tensor t = tensor([0, 1, 3]).

The output should be out = tensor([[0.4869, 0.5144, 0.9086, 0.6139],
[0, 0.8270, 0.4832, 0.8980],
[0, 0, 0, 0.7451]])
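One way to build this without torch.gather is sketched below. It is an illustrative alternative, not necessarily the approach from any linked thread, and it assumes A is 2-D and t holds one non-negative count per row. A column-index range is compared against t (broadcast as a column) to get a boolean mask that is True exactly where col < t[i], and masked_fill writes zeros there. The helper name zero_prefix is hypothetical.

```python
import torch

def zero_prefix(A: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Return a copy of A where row i has its first t[i] entries set to 0.

    Hypothetical helper: assumes A is 2-D and t has one count per row.
    """
    # Column indices 0..n-1 with shape (1, n), on the same device as A.
    cols = torch.arange(A.size(1), device=A.device).unsqueeze(0)
    # t as a column of shape (m, 1); broadcasting gives an (m, n) bool mask
    # that is True where the column index is smaller than that row's count.
    mask = cols < t.to(A.device).unsqueeze(1)
    # masked_fill returns a new tensor with zeros wherever mask is True.
    return A.masked_fill(mask, 0)

A = torch.tensor([[0.4869, 0.5144, 0.9086, 0.6139],
                  [0.5103, 0.8270, 0.4832, 0.8980],
                  [0.5234, 0.1135, 0.1037, 0.7451]])
t = torch.tensor([0, 1, 3])
out = zero_prefix(A, t)
print(out)
```

The only extra allocation besides the result is the boolean mask, which uses one byte per element.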

I have already tried an implementation that uses torch.gather, but it seems to consume a lot of memory and runs out of memory on very large tensors.
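If memory is the main constraint, a variant of the same mask idea can modify A in place and build the mask one block of rows at a time. The only temporary is then a chunk_rows * n boolean mask rather than a full-size index tensor or a second copy of A. This is a sketch under stated assumptions: the name zero_prefix_inplace and the chunk_rows default of 4096 are illustrative, and in-place writes assume A is not a leaf tensor that requires grad.

```python
import torch

def zero_prefix_inplace(A: torch.Tensor, t: torch.Tensor,
                        chunk_rows: int = 4096) -> torch.Tensor:
    """Zero the first t[i] entries of row i of A in place, chunk by chunk.

    Hypothetical helper. Processing rows in chunks bounds the temporary
    boolean mask at chunk_rows * A.size(1) bytes. Assumes A is 2-D and
    does not require grad (in-place ops on such leaves raise an error).
    """
    n_rows, n_cols = A.shape
    cols = torch.arange(n_cols, device=A.device)
    t = t.to(A.device)
    for start in range(0, n_rows, chunk_rows):
        stop = min(start + chunk_rows, n_rows)
        # Mask for this slice of rows only: True where col < t[row].
        mask = cols.unsqueeze(0) < t[start:stop].unsqueeze(1)
        # A[start:stop] is a view, so masked_fill_ writes directly into A.
        A[start:stop].masked_fill_(mask, 0)
    return A
```

For example, zero_prefix_inplace(A, t) zeroes the same entries as the desired out above, but overwrites A instead of returning a copy. Use A.clone() first if the original values are still needed. A smaller chunk_rows lowers peak memory at the cost of more Python-level loop iterations.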

Hi,

I think this thread can help you with this:

But I have not tested it on large tensors.

Bests
