Fastest way to modify tensor using indices

I need to extract elements of a tensor (currently done with a boolean mask), do some operations on them, and put them back into the original tensor. The basic operations look like this, where index is a boolean tensor:

new = out[index]

# Do some operations on new here 

rt = out.clone()
rt[index] = new

Is there a faster way of doing this? In the test below, training time is roughly doubled (though it depends on how many elements of index are True).

import time

import torch
import torch.nn as nn

class test(nn.Module):

    def __init__(self):
        super(test, self).__init__()
        self.a = nn.Parameter(torch.ones([100, 200]))
    
    def forward(self, x, index=None):
    
        out = torch.matmul(self.a, x)
    
        if index is not None:
            new = out[index]     
            rt = out.clone()   
            rt[index] = new
        else:
            rt = out
    
        return rt

model = test()
opt = torch.optim.SGD(lr=0.01, params=model.parameters())
loss = torch.nn.MSELoss()

ixs = (torch.arange(100) > 50)

st = time.time()
for _ in range(1000):

    x = torch.ones([200, 400]) + 10
    target = torch.zeros([100, 400])

    out = model(x, ixs)

    l = loss(out, target)

    l.backward()
    opt.step()
    model.zero_grad()

print(time.time() - st)

Hi,

I’m afraid this is expected.
Your original forward is only a matmul (the most optimized op we have), and you are adding 3 more operations on top of it.

You’ll get better performance if you use indices with index_select/index_copy (or gather/scatter), especially if the number of indices is small.
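
For reference, a minimal sketch of what the index-based variant could look like (the names out and idx are placeholders matching the shapes in the example above, not code from the original post): the boolean mask is replaced by a 1-D LongTensor of row indices, and index_copy_ writes the modified rows back into a clone.

import torch

out = torch.randn(100, 400)             # stands in for the matmul output
idx = torch.arange(51, 100)             # row indices instead of a boolean mask

new = torch.index_select(out, 0, idx)   # copies only the selected rows
# ... do some operations on new here ...
rt = out.clone()
rt.index_copy_(0, idx, new)             # write the modified rows back in place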

Thanks! I’ll have a go with index_select/index_copy and make a comparison. It’s the speed of extracting and re-inserting the elements of the tensor that I’m interested in. Interesting that using indices is quicker with a small number of them.

Well, with a mask the complexity of the operation is proportional to the size of the tensor, because it has to check the mask value for every entry.
With indices, it just reads the values at the given positions, so the complexity is proportional to the number of indices.
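
As an illustration of that point (a small sketch, not from the original thread; mask, idx and rows are just example names): a boolean mask can be converted to indices once, up front, so the per-iteration cost scales with the number of selected rows rather than the full tensor size.

import torch

mask = torch.arange(100) > 50                 # boolean mask over the rows
idx = mask.nonzero(as_tuple=True)[0]          # 1-D LongTensor of the True positions

rows = torch.randn(100, 400)
print(torch.equal(rows[mask], rows.index_select(0, idx)))   # True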

Makes sense - thanks a lot!

The other aspect I’m wondering about in the example above is that if I change the line

ixs = (torch.arange(100) > 50)

to

ixs = (torch.arange(100) > 99)

I get back to the baseline speed. So it seems the slowdown is not from the indexing directly, but from a copy being made of the elements of the matmul output where the mask is True, and the operations on this copy (hence the difference depending on how many elements of ixs are True). Is that roughly correct?

In your code sample, you actually clone the whole tensor, not only the elements you index.
But in any case, the smaller the indexed selection, the faster it’s going to be.
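
To make the distinction concrete, a small illustrative sketch (not from the original posts; shapes match the example above): out[index] with a boolean mask already returns a copy of the selected rows, while .clone() copies the whole tensor no matter how few mask entries are True.

import torch

out = torch.randn(100, 400)
mask = torch.arange(100) > 50

new = out[mask]       # boolean indexing returns a copy of the selected rows
rt = out.clone()      # this copies all 100 rows, regardless of how sparse the mask is
rt[mask] = new * 2.0  # only this step scales with the number of True entries

print(new.shape)      # torch.Size([49, 400])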