I am trying to loop over one tensor and copy its values into another tensor, slot by slot, inside the forward() call of my network. This component alone causes a 5x slowdown of my forward/backward pass, but I can't figure out a better way of implementing it.
I tried using a multiprocessing Pool, which seemed straightforward (the loop body has no intra-dependencies), but torch complains that autograd cannot track gradients across shared non-leaf tensors.
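For reference, the Pool version looked roughly like this (simplified; _assign_one is an illustrative name, not from my real code):

from multiprocessing import Pool

def _assign_one(args):
    # each job writes one independent (b, r) slot, so in principle
    # no job depends on any other
    original, processed, i, b, r = args
    original[b, r] = processed[i, 0:len(r)]

with Pool() as pool:
    pool.map(_assign_one, [(original, processed, i, b, r)
                           for i, (b, r) in enumerate(ids)])
# -> this is where torch raises the autograd / shared non-leaf tensor error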
Am I missing something? Is there any other way to get some performance back?
from typing import Sequence, Tuple

from torch import FloatTensor


def recover_batch(original: FloatTensor,
                  processed: FloatTensor,
                  ids: Sequence[Tuple[int, range]]) -> FloatTensor:
    """
    original:  B x M x D -- non-leaf, requires grad
    processed: T x N x D -- non-leaf, requires grad
    ids: sequence of T pairs (b, r), one per item in processed,
         where b is in [0, B) and r is a subrange of [0, M);
         no two ids overlap!
    """
    # copy each processed item back into its slot of the original batch
    for i, (b, r) in enumerate(ids):
        original[b, r] = processed[i, 0:len(r)]
    return original
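For context, this is roughly how it gets called (a minimal, self-contained sketch with made-up shapes; the * 1.0 just makes the inputs non-leaf, as they are in my real forward()):

import torch

B, M, D = 4, 16, 32    # batch size, max target length, feature dim (made up)
T, N = 3, 8            # number of processed items and their padded length

src_o = torch.randn(B, M, D, requires_grad=True)
src_p = torch.randn(T, N, D, requires_grad=True)
original = src_o * 1.0     # non-leaf tensors that require grad,
processed = src_p * 1.0    # matching the docstring above

ids = [(0, range(0, 5)), (1, range(2, 10)), (3, range(4, 8))]  # non-overlapping

out = recover_batch(original, processed, ids)
out.sum().backward()       # gradients flow back into src_o and src_p
print(src_p.grad.shape)    # torch.Size([3, 8, 32])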