My loss function involves evaluations of a custom autograd function.
Currently, I am using a for loop to compute the loss for every item in a batch with this custom autograd function, which is slow:
loss = torch.stack([func(xi, some_data) for xi in x]).mean()
where func is a custom autograd function and x is a batch of tensors returned from a network.
In principle, this can be vectorized or computed in parallel since each iteration is independent.
I tried to use torch.vmap to vectorize this custom autograd function; however, torch.vmap does not support custom autograd functions at the moment.
(I did not write a vectorized version of the autograd function directly, because this function uses an iterative method to find the root of a function. The number of iterations required differs across inputs, which makes vectorization difficult.)
Is there a way to speed up this for loop by vectorization or parallel computation?