No worries.
Suppose you have a loss function that you don't know in closed form, but only know its values on some kind of lattice (a line, a grid, etc.) of possible outputs. To train with it, you have to interpolate the loss, and in every interpolation region you end up with a different parametrisation of the loss function. You can store these region-specific loss functions in a list and, for every output, do a binary search to find the corresponding region and hence which loss function to use. In principle this could all run in parallel, but the binary search isn't something I can express as a mathematical operation, so I can't put it in tensor form.
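To make that concrete, here is a tiny sketch in plain Python of what I mean (the breakpoints and per-region losses are made-up placeholders, not my real ones):

```python
import bisect

# Made-up example of the setup: the loss is only known at breakpoints, so each
# interval between breakpoints gets its own parametrisation, and a binary
# search picks the interval for a given output.
breakpoints = [0.0, 1.0, 2.0, 3.0]        # lattice of outputs where the loss is known
region_losses = [                          # one parametrisation per interval
    lambda y: 2.0 * y + 1.0,               # loss on [0, 1)
    lambda y: 4.0 - y,                     # loss on [1, 2)
    lambda y: 0.5 * y ** 2,                # loss on [2, 3]
]

def loss_for_output(y):
    """Binary search for the interpolation region of y, then evaluate its loss."""
    idx = bisect.bisect_right(breakpoints, y) - 1
    idx = max(0, min(idx, len(region_losses) - 1))   # clamp to a valid region
    return region_losses[idx](y)

print(loss_for_output(1.5))   # falls in [1, 2) -> 4.0 - 1.5 = 2.5
```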
In my case it's even slightly more complicated because not all examples share the same loss function: every example comes with a label that tells me which loss to use, and only then can I do the interpolation etc.
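Roughly, the per-example selection looks something like this (the labels, breakpoints and loss shapes are again just placeholders):

```python
import bisect

# Rough sketch of the per-example selection: each example's label picks its
# own family of breakpoints and per-region losses.
losses_by_label = {
    "type_a": ([0.0, 1.0, 2.0], [lambda y: 2.0 * y, lambda y: 4.0 - y]),
    "type_b": ([0.0, 2.0, 4.0], [lambda y: y ** 2, lambda y: 3.0 - y]),
}

def loss_for_example(label, y):
    bps, fns = losses_by_label[label]
    idx = bisect.bisect_right(bps, y) - 1      # binary search for the region
    idx = max(0, min(idx, len(fns) - 1))       # clamp to a valid region
    return fns[idx](y)

print(loss_for_example("type_b", 2.5))         # 3.0 - 2.5 = 0.5
```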
I ended up doing this sequentially for every output, saving the gradients I compute along the way, and finally setting the loss to outputs * gradients. This effectively solves the problem, but it is very slow at the moment since everything runs sequentially.
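For reference, a minimal PyTorch sketch of that workaround; the model, labels and the per-example derivative are stand-ins for my real piecewise losses:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
inputs = torch.randn(8, 4)
labels = torch.randint(0, 2, (8,))

def dloss_dy(label, y):
    # placeholder: in reality this does the label lookup, the binary search
    # and the interpolation, and returns d(loss)/d(output) at y
    return 2.0 * y if label == 0 else -1.0

outputs = model(inputs).squeeze(-1)            # shape: (batch,)

grads = torch.zeros_like(outputs)
for i in range(outputs.shape[0]):              # the slow, sequential part
    grads[i] = dloss_dy(int(labels[i]), outputs[i].item())

# Multiplying the detached per-output gradients back onto the outputs gives a
# scalar whose backward pass hands exactly those gradients to the outputs.
surrogate = (outputs * grads.detach()).sum()
surrogate.backward()
```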
This is very specific to what I am doing at the moment, but it would be nice if there were a version of vmap that could apply an arbitrary function f(x) to every single element of a tensor in parallel.