For-loop indexing and value assignment


I am trying to iterate over a tensor and assign values into it from another tensor inside the forward() call of my network. This component alone slows my forward/backward pass down by 5x, but I can't figure out a better way of implementing it.

I tried using a multiprocessing Pool, which seemed easy (my code has no intra-dependencies), but torch complains that autograd cannot track gradients between shared non-leaf tensors.

Am I missing something? Is there any other way to get some performance back?


from typing import Sequence, Tuple

from torch import FloatTensor


def recover_batch(original: FloatTensor, processed: FloatTensor, ids: Sequence[Tuple[int, range]]) -> FloatTensor:
    """original:  B x M x D  -- non-leaf, requires grad
    processed: T x N x D  -- non-leaf, requires grad
    id: a pair (b, r) where b is in [0, B) and r is a subrange of [0, M)
    ids: sequence of T ids identifying each item in processed; no two ids overlap
    """
    for i, (b, r) in enumerate(ids):
        # copy the first len(r) rows of processed item i into batch element b
        original[b, r] = processed[i, 0:len(r)]
    return original
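One common way to remove the per-item Python loop (and the many small autograd ops it creates) is to flatten all (b, r) pairs into integer index tensors once, then do a single gather from `processed` and a single out-of-place `index_put` into `original`. The sketch below assumes `ids` fits in memory as plain Python ints. The helper names `build_index` and `recover_batch_vectorized` are hypothetical. `build_index` still loops in Python, but only over integers with no autograd involvement, so its result can be cached if the same `ids` are reused across steps.

```python
from typing import List, Optional, Sequence, Tuple

import torch
from torch import Tensor


def build_index(ids: Sequence[Tuple[int, range]], device: Optional[torch.device] = None):
    """Hypothetical helper: flatten (b, r) pairs into four index tensors.

    Destination (b_idx[k], m_idx[k]) in `original` receives
    source (t_idx[k], n_idx[k]) from `processed`.
    """
    b_idx: List[int] = []
    m_idx: List[int] = []
    t_idx: List[int] = []
    n_idx: List[int] = []
    for i, (b, r) in enumerate(ids):
        for j, m in enumerate(r):
            b_idx.append(b)
            m_idx.append(m)
            t_idx.append(i)
            n_idx.append(j)

    def as_long(xs: List[int]) -> Tensor:
        return torch.tensor(xs, dtype=torch.long, device=device)

    return as_long(b_idx), as_long(m_idx), as_long(t_idx), as_long(n_idx)


def recover_batch_vectorized(original: Tensor, processed: Tensor,
                             ids: Sequence[Tuple[int, range]]) -> Tensor:
    b_idx, m_idx, t_idx, n_idx = build_index(ids, device=original.device)
    # One gather of every source row: shape (K, D)
    values = processed[t_idx, n_idx]
    # Out-of-place index_put returns a new tensor, so `original` is not
    # modified in place; gradients flow to both `original` and `processed`.
    # accumulate=False is safe because the ids never overlap.
    return original.index_put((b_idx, m_idx), values)
```

Because `index_put` here is the out-of-place variant, it also sidesteps errors autograd can raise when a non-leaf tensor that is needed for backward gets overwritten in place.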

I would really like to know the answer to this too; I'm facing a similar issue.