Hi all, I’ve recently noticed that an indexed assignment (i.e. `r[idx] = something`) adds significant overhead. Does anyone have insight into why this happens, and how to mitigate it?
Here is a snippet to reproduce it:
import time

import numpy as np
import torch


def time_ms(fn, n_iters=1000):
    """Time repeated calls of a zero-argument callable.

    Args:
        fn: Callable invoked with no arguments; its return value is ignored.
        n_iters: Number of timed calls.

    Returns:
        A list of ``n_iters`` wall-clock durations, one per call, in
        milliseconds (``perf_counter_ns`` nanoseconds divided by 1e6).
    """
    times = []
    for _ in range(n_iters):
        start = time.perf_counter_ns()
        fn()
        times.append((time.perf_counter_ns() - start) / 1e6)
    return times


def report(label, times):
    """Print *label*, then the mean and standard deviation of *times* in ms."""
    print(label)
    print(f"{np.mean(times)}ms pm {np.std(times)}ms\n")


def main(n_iters=1000):
    """Compare a plain matmul against indexed-read and indexed-write variants.

    Args:
        n_iters: Number of timed iterations per variant.
    """
    # One intra-op thread, so timings are not affected by thread scheduling.
    torch.set_num_threads(1)
    r = torch.zeros(1000, 512)
    m1 = torch.randn(1, 1000)
    m2 = torch.rand(1, 512)
    idx = torch.arange(start=0, end=1000, step=1).long()

    report("nothing", time_ms(lambda: m1.t().mm(m2), n_iters))

    # m1[:, idx] uses an integer-tensor (advanced) index. That always copies
    # the selected elements, whereas a basic slice such as m1[:, 0:1000]
    # returns a view.
    report("m1 indexing", time_ms(lambda: m1[:, idx].t().mm(m2), n_iters))

    def assign():
        # Assigning through an integer-tensor index dispatches to index_put_,
        # which scatters into r. idx is 0..999 in order, so `r[:] = ...` or
        # `r.copy_(...)` would write the same values without the indexed path.
        r[idx] = m1[:, idx].t().mm(m2)

    report("r and m1 indexing", time_ms(assign, n_iters))


if __name__ == "__main__":
    main()
which prints:
nothing
0.8633705ms pm 0.1330834445742595ms

m1 indexing
0.9316685ms pm 0.26098909218155075ms

r and m1 indexing
1.4100757999999998ms pm 0.4815578560197726ms
Thank you all.