Speed up an unparallelizable for loop

I built a network in PyTorch, and profiling showed that ~90% of the work is done in a for loop in one of my blocks. The problem is that this loop is not parallelizable: each iteration depends on the positions that earlier iterations masked out in mask1 (see the MWE below).
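The loop-carried dependency can be illustrated with a stripped-down toy. This is a hypothetical sketch, not the MWE itself: greedy_sweep and neighbours stand in for the mask1 logic, where a position is processed only if no earlier iteration has masked it, and processing it masks out its neighbourhood. Because the outcome depends on the visiting order, the iterations cannot simply run independently in parallel.

```python
def greedy_sweep(neighbours, order):
    """Visit nodes in `order`; keep a node only if no earlier kept node masked it."""
    active = [True] * len(neighbours)
    kept = []
    for i in order:
        if not active[i]:
            continue  # masked by an earlier iteration: the loop-carried dependency
        kept.append(i)
        active[i] = False
        for j in neighbours[i]:
            active[j] = False  # later iterations will see this
    return kept


path = [[1], [0, 2], [1, 3], [2]]  # path graph 0-1-2-3
print(greedy_sweep(path, range(4)))            # [0, 2]
print(greedy_sweep(path, reversed(range(4))))  # [3, 1]
```

The two visiting orders keep different nodes, so step i genuinely needs the masking done by steps 0..i-1.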

I tried compiling it with @torch.jit.script, but the speedup was negligible (about 0.5 s). The loop already does the minimum amount of work; everything else is vectorized. Below is an MWE that uses the tensor sizes I am working with.
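Since the comparison is between wall-clock timings of GPU code, both variants (eager and TorchScript, or any C++ port) should be measured the same way. CUDA kernels run asynchronously, so the clock should only stop after the device has finished its queued work. A minimal sketch with a hypothetical helper time_it, assuming sync is torch.cuda.synchronize for GPU runs and None for CPU runs:

```python
import time
from typing import Callable, Optional


def time_it(fn: Callable[[], object],
            sync: Optional[Callable[[], None]] = None,
            repeats: int = 3) -> float:
    """Return the best wall-clock time of fn() over `repeats` runs, in seconds.

    `sync` (hypothetical parameter; e.g. torch.cuda.synchronize on a GPU)
    is called before starting and before stopping the clock, so kernels
    still queued on the device are included in the measurement.
    """
    best = float("inf")
    for _ in range(repeats):
        if sync is not None:
            sync()  # drain work queued before this run
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for this run's kernels to finish
        best = min(best, time.perf_counter() - start)
    return best
```

For example, time_it(run_loop, sync=torch.cuda.synchronize), where run_loop is a hypothetical wrapper around the MWE loop. Because the loop mutates mask1, mask2 and mask3 in place, run_loop should recreate them on every call; otherwise later repeats find mask1 already cleared and skip almost every iteration.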

Would rewriting this loop in C++ be faster than what TorchScript produces? Is there any other way to considerably reduce the loop's runtime?
Thank you.

import torch
import time

n = 11300
f = 7800
batch_size = 10
device = "cuda:3" if torch.cuda.is_available() else "cpu"

inp1 = torch.randint(0, n, size=[batch_size, n, 4], device=device)
inp2 = torch.randint(0, n, size=[batch_size, f, 3], device=device)
inp3 = torch.randint(0, f, size=[batch_size, n, 2], device=device)
inp4 = torch.randint(0, n, size=[batch_size, n, 2], device=device)
inp5 = torch.randint(0, n, size=[batch_size, n, 2], device=device)

batch_list = torch.arange(batch_size, device=device)

mask1 = torch.ones([batch_size, n], dtype=torch.bool, device=device)
mask2 = torch.ones([batch_size, n], dtype=torch.bool, device=device)
mask3 = torch.ones([batch_size, n], dtype=torch.bool, device=device)

start_time = time.time()
for i in range(n):
    if torch.all(~mask1[:, i]):
        continue  # position i is already masked in every batch element
    batch_list_tmp = batch_list[mask1[:, i]]
    mask2[mask1[:, i], i] = False
    mask3[batch_list_tmp.unsqueeze(-1), inp1[:, i][:, ::2][batch_list_tmp]] = 0
    mask1[batch_list_tmp, i] = False
    concat_list_1 = inp5[
        batch_list.unsqueeze(-1), inp4[batch_list.unsqueeze(-1), inp1[batch_list, i]].view([batch_size, -1])]
    concat_list_2 = inp5[batch_list.unsqueeze(-1).unsqueeze(-1), inp4[
        batch_list.unsqueeze(-1).unsqueeze(-1), concat_list_1[batch_list]].view([batch_size, 8, -1])[
        batch_list]].view([batch_size, 8, -1])
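    # inp3 maps indices in [0, n) to [0, f); inp2 maps indices in [0, f) back to [0, n)
    closure = inp2[batch_list.unsqueeze(-1).unsqueeze(-1), inp3[
        batch_list.unsqueeze(-1).unsqueeze(-1), concat_list_2[batch_list]].view(batch_size, 8, -1)[
        batch_list]].view([batch_size, 8, -1])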
    mask1[batch_list_tmp.unsqueeze(-1), closure.view([batch_size, -1])[batch_list_tmp]] = False
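if device.startswith("cuda"):
    torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
end_time = time.time()

print(end_time - start_time)  # takes ~5-7 seconds on the server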
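The loop does very little arithmetic per step; on the GPU, each iteration launches many small kernels and forces host-device synchronization (the Python if on torch.all(...) and the boolean-mask indexing both need results on the host). One way to gauge what a C++ (or other compiled, CPU-side) version would have to do is a per-batch-element reference written with plain lists. This is a sketch, not the author's code: sweep_one_batch and expand are hypothetical names, and it assumes closure is computed as inp2[inp3[concat_list_2]] for each batch element, the only reading under which the value range of inp3 ([0, f)) matches the rows of inp2. Each batch element only updates its own row of the masks, so only the loop over i is sequential.

```python
from typing import List, Sequence, Tuple


def expand(ids: Sequence[int], table: Sequence[Sequence[int]]) -> List[int]:
    # Gather table[v] for every id v and flatten, like t[b, idx].view(-1).
    return [x for v in ids for x in table[v]]


def sweep_one_batch(inp1, inp2, inp3, inp4, inp5) -> Tuple[List[bool], List[bool], List[bool]]:
    """Sequential reference of the MWE loop for ONE batch element (hypothetical)."""
    n = len(inp1)
    mask1 = [True] * n
    mask2 = [True] * n
    mask3 = [True] * n
    for i in range(n):
        if not mask1[i]:
            continue  # masked by an earlier iteration: the loop-carried dependency
        mask2[i] = False
        for v in inp1[i][::2]:  # entries 0 and 2 of inp1[i]
            mask3[v] = False
        mask1[i] = False
        concat_list_1 = expand(expand(inp1[i], inp4), inp5)        # 16 ids in [0, n)
        concat_list_2 = expand(expand(concat_list_1, inp4), inp5)  # 64 ids in [0, n)
        closure = expand(expand(concat_list_2, inp3), inp2)        # 384 ids (assumes the inp3 step)
        for v in closure:
            mask1[v] = False
    return mask1, mask2, mask3


# Tiny hand-made example: n = f = 4, identity-like inp3/inp4/inp5,
# inp2 rows are path-graph neighbourhoods 0-1-2-3.
ident = [[v, v] for v in range(4)]
inp1 = [[i, i, i, i] for i in range(4)]
inp2 = [[0, 1, 1], [0, 1, 2], [1, 2, 3], [2, 3, 3]]
m1, m2, m3 = sweep_one_batch(inp1, inp2, ident, ident, ident)
print(m2)  # [False, True, False, True]
```

It can be called per batch element with lists from Tensor.tolist(), e.g. sweep_one_batch(inp1[b].tolist(), inp2[b].tolist(), inp3[b].tolist(), inp4[b].tolist(), inp5[b].tolist()), and under the assumption above its masks should match mask1[b], mask2[b] and mask3[b] from the MWE. Pure Python is likely to be slow at n = 11300 itself; the point of this form is that it maps line by line onto a C++ loop (or a Numba-compiled function over NumPy arrays) with no per-step kernel launches.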