I want to parallelize code that looks roughly like this:
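```python
import torch
from torch.nn import ModuleList

L = ModuleList()  # one module per index list
x = torch.randn(batch_size, n, d)
y = torch.zeros(batch_size, n, d)
idx = ...  # list of lists of indices (different lengths)

for i, module in enumerate(L):
    # each module only sees, and only adds to, the positions listed in idx[i]
    y[:, idx[i], :] += module(x[:, idx[i], :])
```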
Note that the index lists can overlap: some idx[i][j] can be equal to another idx[i'][j'], so several modules may add into the same position of y.
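The overlap starts to matter once the updates from several lists are applied in a single operation. A toy example (sizes and values chosen only for illustration) contrasting advanced-indexed `+=` with `Tensor.index_add_` when an index repeats:

```python
import torch

# Toy sizes: one batch element, three positions, feature size 1.
y_indexed = torch.zeros(1, 3, 1)
y_index_add = torch.zeros(1, 3, 1)

# Position 0 appears twice, as it would if two overlapping index lists
# were merged into one update.
idx = torch.tensor([0, 0, 2])
updates = torch.ones(1, 3, 1)

# Advanced-indexed "+=" reads a copy, adds to it, then writes it back:
# for the repeated position only one of the writes survives.
y_indexed[:, idx, :] += updates

# index_add_ sums every contribution, including repeated indices.
y_index_add.index_add_(1, idx, updates)

print(y_indexed.flatten().tolist())    # [1.0, 0.0, 1.0]
print(y_index_add.flatten().tolist())  # [2.0, 0.0, 1.0]
```

Inside the sequential loop each list is applied separately, so cross-list overlap is handled correctly there; the difference only shows up when the lists are combined.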
Since each iteration only adds into y, the iterations can be done in any order, so I'm trying to parallelize this for loop. Is there any way to do it?
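A minimal sketch of one possible restructuring, assuming every module maps a (batch_size, k, d) tensor to (batch_size, k, d); the `nn.Linear` modules, the sizes and the `idx` lists below are hypothetical stand-ins. Each module call only reads `x` and returns its output, and a single `index_add_` then sums all outputs into `y`, including overlapping positions:

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, n, d = 2, 5, 4

# Hypothetical setup: three Linear layers and overlapping index lists
# of different lengths (positions 0 and 1 are shared between lists).
L = nn.ModuleList([nn.Linear(d, d) for _ in range(3)])
idx = [[0, 1], [1, 2, 3], [0, 4]]
x = torch.randn(batch_size, n, d)

# Reference result: the original sequential loop.
y_loop = torch.zeros(batch_size, n, d)
for i, module in enumerate(L):
    y_loop[:, idx[i], :] += module(x[:, idx[i], :])

# Restructured version: the module calls share no writes to y ...
outs = [module(x[:, ids, :]) for module, ids in zip(L, idx)]

# ... and one index_add_ along dim 1 combines them, summing overlaps.
flat_idx = torch.tensor([j for ids in idx for j in ids])
y = torch.zeros(batch_size, n, d)
y.index_add_(1, flat_idx, torch.cat(outs, dim=1))

print(torch.allclose(y, y_loop))  # True
```

The calls in the list comprehension still run one after another; the change is only that they no longer write to a shared `y`, so dispatching them concurrently would not require coordinating on it.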