Indexed assignment has a significant overhead

Hi all, I’ve recently noticed that an indexed assignment (i.e. r[idx] = something) incurs quite a lot of overhead. Does anyone have some insight into why this happens, and any way to mitigate it?

Here is a snippet to reproduce:

import time

import numpy as np
import torch

torch.set_num_threads(1)

r = torch.zeros(1000, 512)
m1 = torch.randn(1, 1000)
m2 = torch.rand(1, 512)

idx = torch.arange(start=0, end=1000, step=1).long()

times = []
for i in range(1000):
    start = time.perf_counter_ns()
    r = m1.t().mm(m2)
    end = (time.perf_counter_ns() - start) / 1e+6
    times.append(end)
print("nothing")
print(f"{np.mean(times)}ms pm {np.std(times)}ms\n")

times = []
for i in range(1000):
    start = time.perf_counter_ns()
    r = m1[:, idx].t().mm(m2)
    end = (time.perf_counter_ns() - start) / 1e+6
    times.append(end)
print("m1 indexing")
print(f"{np.mean(times)}ms pm {np.std(times)}ms\n")

times = []
for i in range(1000):
    start = time.perf_counter_ns()
    r[idx] = m1[:, idx].t().mm(m2)
    end = (time.perf_counter_ns() - start) / 1e+6
    times.append(end)
print("r and m1 indexing")
print(f"{np.mean(times)}ms pm {np.std(times)}ms\n")

which results in

nothing
0.8633705ms pm 0.1330834445742595ms

m1 indexing
0.9316685ms pm 0.26098909218155075ms

r and m1 indexing
1.4100757999999998ms pm 0.4815578560197726ms
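As a side note, for timings in the sub-millisecond range torch.utils.benchmark.Timer may give more stable numbers than a hand-rolled perf_counter loop, since it handles warmup and repeats internally. A minimal sketch of timing the indexed assignment this way (the tensors and idx are the same as above):

```python
import torch
import torch.utils.benchmark as benchmark

torch.set_num_threads(1)

r = torch.zeros(1000, 512)
m1 = torch.randn(1, 1000)
m2 = torch.rand(1, 512)
idx = torch.arange(1000)

# Timer compiles the stmt once and handles warmup/repeats internally
t = benchmark.Timer(
    stmt="r[idx] = m1[:, idx].t().mm(m2)",
    globals={"r": r, "m1": m1, "m2": m2, "idx": idx},
)
m = t.timeit(100)
print(m)
```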

Thank you all.

I’m not sure I understand the question completely, but indexing tensors adds some overhead, which is expected.

@ptrblck what I’d like to do is compute the mm only for a subset of the tensor and have the resulting tensor be nonzero only for the corresponding subset.

import torch

r = torch.zeros(10, 5)
m1 = torch.randn(1, 10)
m2 = torch.rand(1, 5)

idx = torch.tensor([0, 1, 3, 4, 6])

r[idx] = m1[:, idx].t().mm(m2)

Here r will be zero at indices 2, 5, 7, 8, 9 and nonzero at the remaining ones.
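As a quick sanity check of that property, the rows of r outside idx can be asserted to stay all-zero (the name missing below is just introduced for this sketch):

```python
import torch

r = torch.zeros(10, 5)
m1 = torch.randn(1, 10)
m2 = torch.rand(1, 5)
idx = torch.tensor([0, 1, 3, 4, 6])

# write the subset result only into the rows listed in idx
r[idx] = m1[:, idx].t().mm(m2)

# rows not listed in idx should remain all-zero
missing = torch.tensor([2, 5, 7, 8, 9])
assert torch.all(r[missing] == 0)
```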

What I’ve noticed is that the indexed mm (m1[:, idx].t().mm(m2)) adds no overhead (it is actually faster than the full, non-subset mm), but the indexed assignment (r[idx] = ...) has significant overhead (it roughly doubles the total execution time).

Is there any way to accomplish this operation without such an increase in time?
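One alternative that may be worth benchmarking (a sketch, not a verified speedup) is Tensor.index_copy_, which writes the rows selected by an index tensor in place and may carry less dispatch overhead than advanced-indexing assignment; the result is the same as r[idx] = ...:

```python
import torch

r = torch.zeros(1000, 512)
m1 = torch.randn(1, 1000)
m2 = torch.rand(1, 512)
idx = torch.arange(1000)  # int64, as index_copy_ requires

sub = m1[:, idx].t().mm(m2)

# equivalent to r[idx] = sub, but via the in-place index_copy_ op
r.index_copy_(0, idx, sub)

# sanity check against plain indexed assignment
r2 = torch.zeros(1000, 512)
r2[idx] = sub
assert torch.equal(r, r2)
```

Whether index_copy_ is actually faster here depends on the backend and tensor sizes, so it would need to be measured with the same benchmark loop as above.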