I need to extract elements of a tensor (at present using a boolean mask), do some operations on them, and put them back into the original tensor. The basic operations look like this, where index is a boolean tensor:
```python
new = out[index]
# Do some operations on new here
rt = out.clone()
rt[index] = new
```
Is there a faster way of doing this? In the test below, training time roughly doubles (though this depends on how many elements of index are True).
```python
import time

import torch
import torch.nn as nn


class test(nn.Module):
    def __init__(self):
        super(test, self).__init__()
        self.a = nn.Parameter(torch.ones([100, 200]))

    def forward(self, x, index=None):
        out = torch.matmul(self.a, x)
        if index is not None:
            # Extract the masked rows and write them back into a copy
            new = out[index]
            rt = out.clone()
            rt[index] = new
        else:
            rt = out
        return rt


model = test()
opt = torch.optim.SGD(lr=0.01, params=model.parameters())
loss = torch.nn.MSELoss()
ixs = (torch.arange(100) > 50)  # True for rows 51..99

st = time.time()
for _ in range(1000):
    x = torch.ones([200, 400]) + 10
    target = torch.zeros([100, 400])
    out = model(x, ixs)
    l = loss(out, target)
    l.backward()
    opt.step()
    model.zero_grad()
print(time.time() - st)
```
Thanks! I'll have a go with index_select/index_copy and make a comparison. It's the speed of extracting and re-inserting the elements of the tensor that I'm interested in. Interesting that using indices is quicker when there are only a few of them.
Well, the cost of the operation with a mask scales with the size of the tensor, because for every entry it needs to check the mask value.
With indices, it just reads the values at the given positions, so the cost scales with the number of indices.
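A minimal sketch of the index-based variant, assuming (as in the test script) that the boolean mask is 1-D and selects rows along dim 0. The helper names replace_with_mask and replace_with_indices are hypothetical. Note that nonzero() still scans the whole mask once, so in a training loop it is probably worth computing idx once outside the loop and reusing it; the out-of-place index_copy also returns a new tensor, so it pays for a full copy much like clone() does in the original.

```python
import torch


def replace_with_mask(out, mask):
    # Pattern from the question: boolean-mask gather, then scatter into a copy.
    new = out[mask]
    # ... operations on `new` would go here ...
    rt = out.clone()
    rt[mask] = new
    return rt


def replace_with_indices(out, idx):
    # Gather only the selected rows (len(idx) rows are read).
    new = out.index_select(0, idx)
    # ... operations on `new` would go here ...
    # Out-of-place index_copy: returns a new tensor with those rows replaced.
    return out.index_copy(0, idx, new)


out = torch.randn(100, 400)
mask = torch.arange(100) > 50
# Convert the 1-D row mask to integer row indices once (e.g. outside the training loop).
idx = mask.nonzero(as_tuple=True)[0]

print(torch.equal(replace_with_mask(out, mask), replace_with_indices(out, idx)))  # True
```

Both functions produce the same tensor; which one is faster in practice depends on how many rows are selected and on the hardware, so it is worth timing both in the real loop.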
The other thing I'm wondering about in the example above is that if I change the line
```python
ixs = (torch.arange(100) > 50)
```
to
```python
ixs = (torch.arange(100) > 99)
```
(so that no entries of the mask are True), I get back to the baseline speed. So it seems the slowdown is not the indexing itself, but the copy made of the matmul-output elements where ixs is True, and the operations on that copy (hence the dependence on how many elements of ixs are True). Is that roughly correct?
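One way to check this hypothesis is to time the forward and backward passes while varying how many rows the mask selects. This is a minimal sketch assuming the same shapes as the test script; the function name time_forward_backward is hypothetical, and the optimizer step is deliberately left out so that only the masking, loss and backward pass are timed.

```python
import time

import torch
import torch.nn as nn


def time_forward_backward(threshold, steps=100):
    """Time forward+backward of the mask/clone/scatter pattern.

    threshold=None skips the masking entirely (baseline).
    Returns (seconds, number of True rows in the mask).
    """
    a = nn.Parameter(torch.ones(100, 200))
    x = torch.ones(200, 400) + 10
    target = torch.zeros(100, 400)
    loss_fn = nn.MSELoss()
    mask = None if threshold is None else torch.arange(100) > threshold

    start = time.perf_counter()
    for _ in range(steps):
        out = a @ x
        if mask is not None:
            rt = out.clone()
            rt[mask] = out[mask]  # gather then scatter the selected rows
        else:
            rt = out
        loss = loss_fn(rt, target)
        a.grad = None  # drop the previous gradient instead of accumulating
        loss.backward()
    elapsed = time.perf_counter() - start
    n_true = 0 if mask is None else int(mask.sum())
    return elapsed, n_true


for threshold in (None, 99, 75, 50, -1):
    seconds, n_true = time_forward_backward(threshold)
    print(f"threshold={threshold}: {n_true} rows selected, {seconds:.3f} s")
```

Comparing threshold=None with threshold=99 separates the fixed cost of clone() plus an all-False mask from the cost that grows with the number of selected rows; absolute timings will vary with hardware and thread count.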