I built a network in PyTorch and, after profiling, found that ~90% of the runtime is spent in a for loop inside one of my blocks. The problem is that this loop is not parallelizable, because each iteration depends on values that were masked by mask1 in previous iterations (see MWE below).
I tried compiling it with @torch.jit.script, but the speedup was negligible (~0.5 s). I am doing the minimal amount of work inside the loop; everything else has been vectorized. Attached is a MWE with the tensor sizes I am actually working with.
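For reference, the scripting attempt looked roughly like this. `masked_loop` below is a hypothetical, stripped-down stand-in for my real block (tiny sizes, only the mask-propagation part), not the actual code:

```python
import torch

# Hypothetical, reduced stand-in for the real block: the data-dependent
# mask updates are what create the serial dependency between iterations.
@torch.jit.script
def masked_loop(mask: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
    batch = torch.arange(mask.shape[0], device=mask.device)
    n = mask.shape[1]
    for i in range(n):
        active = mask[:, i]
        if not bool(active.any()):
            continue
        # Clear the current column for the active batch elements ...
        mask[batch[active], i] = False
        # ... and clear the columns they point to, which later iterations read.
        mask[batch[active].unsqueeze(-1), inp[active][:, i]] = False
    return mask

mask = torch.ones(2, 16, dtype=torch.bool)
inp = torch.randint(0, 16, size=(2, 16, 3))
out = masked_loop(mask, inp)
```

Scripting removes Python interpreter overhead per iteration, but each iteration still launches its kernels sequentially, which is presumably why the gain was so small.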
Would rewriting this in C++ be better than what TorchScript produces? Is there any other way to considerably improve the runtime of the loop?
Thank you.
import torch
import time
n = 11300
f = 7800
batch_size = 10
device = "cuda:3" if torch.cuda.is_available() else "cpu"
inp1 = torch.randint(0, n, size=[batch_size, n, 4], device=device)
inp2 = torch.randint(0, n, size=[batch_size, f, 3], device=device)
inp3 = torch.randint(0, f, size=[batch_size, n, 2], device=device)
inp4 = torch.randint(0, n, size=[batch_size, n, 2], device=device)
inp5 = torch.randint(0, n, size=[batch_size, n, 2], device=device)
batch_list = torch.arange(batch_size, device=device)
# All-True masks; the loop below progressively clears entries.
mask1 = torch.ones([batch_size, n], dtype=torch.bool, device=device)
mask2 = torch.ones([batch_size, n], dtype=torch.bool, device=device)
mask3 = torch.ones([batch_size, n], dtype=torch.bool, device=device)
start_time = time.time()
for i in range(n):
    # Skip columns that every batch element has already cleared.
    if torch.all(~mask1[:, i]):
        continue
    batch_list_tmp = batch_list[mask1[:, i]]
    mask2[mask1[:, i], i] = False
    mask3[batch_list_tmp.unsqueeze(-1), inp1[:, i][:, ::2][batch_list_tmp]] = 0
    mask1[batch_list_tmp, i] = False
    # Two levels of gather through inp4/inp5, then a lookup through inp3/inp2.
    concat_list_1 = inp5[
        batch_list.unsqueeze(-1),
        inp4[batch_list.unsqueeze(-1), inp1[batch_list, i]].view([batch_size, -1])]
    concat_list_2 = inp5[
        batch_list.unsqueeze(-1).unsqueeze(-1),
        inp4[batch_list.unsqueeze(-1).unsqueeze(-1),
             concat_list_1[batch_list]].view([batch_size, 8, -1])[batch_list]
        ].view([batch_size, 8, -1])
    closure = inp2[
        batch_list.unsqueeze(-1).unsqueeze(-1),
        inp3[batch_list.unsqueeze(-1).unsqueeze(-1),
             concat_list_2[batch_list]].view(batch_size, 8, -1)[batch_list]
        ].view([batch_size, 8, -1])
    # The serial dependency: mask1 updated here is read by later iterations.
    mask1[batch_list_tmp.unsqueeze(-1), closure.view([batch_size, -1])[batch_list_tmp]] = False
if device.startswith("cuda"):
    torch.cuda.synchronize()  # make sure all queued kernels finish before timing
end_time = time.time()
print(end_time - start_time)  # takes ~5-7 seconds on the server
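One caveat about wall-clock numbers like the one above: CUDA kernels launch asynchronously, so `time.time()` around GPU code is only reliable if the device is synchronized before each clock read. A minimal sketch of such a harness (`timed` is a hypothetical helper name, not a torch API):

```python
import time
import torch

def timed(fn):
    """Run fn() and return (result, elapsed_seconds), synchronizing the
    CUDA device before and after so queued kernels are fully accounted for."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    result = fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.time() - start

# Usage: wrap the measured work in a zero-argument callable.
out, elapsed = timed(lambda: torch.ones(8, 8).sum())
```

On CPU this degrades gracefully to a plain timer, since the synchronize calls are skipped.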