I am trying to loop over one tensor and copy its values into another tensor, slot by slot, inside the forward() call of my network. This component alone causes a 5x slowdown of my forward/backward pass, but I can't figure out a better way of implementing it.
I tried using a multiprocessing Pool, which seemed straightforward (the loop body has no intra-dependencies), but torch complains that autograd cannot track gradients across shared non-leaf tensors.
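For reference, the Pool version looked roughly like this (simplified; _assign_one is an illustrative name, not from my real code):

from multiprocessing import Pool

def _assign_one(args):
    # each job writes one independent (b, r) slot, so in principle
    # no job depends on any other
    original, processed, i, b, r = args
    original[b, r] = processed[i, 0:len(r)]

with Pool() as pool:
    pool.map(_assign_one, [(original, processed, i, b, r)
                           for i, (b, r) in enumerate(ids)])
# -> this is where torch raises the autograd / shared non-leaf tensor error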
Am I missing something? Is there any other way to get some performance back?
from typing import Sequence, Tuple

from torch import FloatTensor


def recover_batch(original: FloatTensor,
                  processed: FloatTensor,
                  ids: Sequence[Tuple[int, range]]) -> FloatTensor:
    """
    original:  B x M x D -- non-leaf, requires grad
    processed: T x N x D -- non-leaf, requires grad
    ids: sequence of T pairs (b, r), one per item in processed,
         where b is in [0, B) and r is a subrange of [0, M);
         no two ids overlap!
    """
    # copy each processed item back into its slot of the original batch
    for i, (b, r) in enumerate(ids):
        original[b, r] = processed[i, 0:len(r)]
    return original
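For context, this is roughly how it gets called (a minimal, self-contained sketch with made-up shapes; the * 1.0 just makes the inputs non-leaf, as they are in my real forward()):

import torch

B, M, D = 4, 16, 32    # batch size, max target length, feature dim (made up)
T, N = 3, 8            # number of processed items and their padded length

src_o = torch.randn(B, M, D, requires_grad=True)
src_p = torch.randn(T, N, D, requires_grad=True)
original = src_o * 1.0     # non-leaf tensors that require grad,
processed = src_p * 1.0    # matching the docstring above

ids = [(0, range(0, 5)), (1, range(2, 10)), (3, range(4, 8))]  # non-overlapping

out = recover_batch(original, processed, ids)
out.sum().backward()       # gradients flow back into src_o and src_p
print(src_p.grad.shape)    # torch.Size([3, 8, 32])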