My loss function involves evaluations of a custom autograd function.
Currently, I am using a for loop to compute the loss for every item in a batch with this custom autograd function, which is slow:
loss = torch.stack([func(xi, some_data) for xi in x]).mean()
where func
is a custom autograd function and x
is a batch of tensors returned from a network.
In principle, this can be vectorized or computed in parallel since each iteration is independent.
I tried to use torch.vmap
to vectorize this custom autograd function. However, torch.vmap
does not support autograd functions at the moment.
(I did not write a vectorized version of the autograd function directly, because its forward pass uses an iterative method to find the root of a function. Different inputs require different numbers of iterations, which makes vectorization difficult.)
Is there a way to speed up this for loop by vectorization or parallel computation?