Hi Richard, thanks for your response. I figured out that I could do this if I turn coeffs into a tensor. However, in my real use case coeffs is a list of lambdas. Each lambda does a binary search to locate the region that determines which linear function to apply to the input tensor. As far as I know, a tensor can’t hold a custom dtype such as a lambda, right?

In theory, if I have a function of some primitives (in my case a pair of floats and an int), I should be able to parallelise it, so there should be a way to construct the computation graph and speed up the calculation with something like PyTorch. This is easy to do in TF with tf.map_fn, but I’d like to stay away from TF if possible.

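Edit: You can use np.vectorize with this kind of function. This at least shows that the computation is element-wise and therefore parallelisable in principle (np.vectorize itself is essentially a Python loop, so it doesn’t actually run in parallel), but unfortunately I can’t get gradients automatically unless I use something like vmap.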
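For reference, a scalar version of the kind of function described above might look like the sketch below. The lattice points (knots) and the per-region slopes and intercepts are made-up illustrative values, not the actual loss, and Python’s bisect module stands in for the binary search. np.vectorize maps the function over an array, but the result is a plain NumPy array with no autograd history.

```python
import bisect

import numpy as np

# Hypothetical lattice and per-region linear pieces (illustrative values only).
knots = [0.0, 1.0, 2.0, 3.0]
slopes = [1.0, 2.0, 3.0]
intercepts = [0.0, -1.0, -3.0]

def interp_loss(x):
    # Binary search for the region i with knots[i] <= x < knots[i + 1].
    i = bisect.bisect_right(knots, x) - 1
    # Values outside the lattice reuse the first/last region.
    i = min(max(i, 0), len(slopes) - 1)
    return slopes[i] * x + intercepts[i]

# np.vectorize calls interp_loss once per element (a loop, not parallel code).
vec_loss = np.vectorize(interp_loss)
result = vec_loss(np.array([0.5, 1.5, 2.5]))  # values 0.5, 2.0, 4.5
```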

Sorry, I missed your reply (I don’t get notifications from the PyTorch forums, but I’ll try to check more often).

You’re right that we can’t put a lambda into a Tensor. Do you have a concrete example of how you’d like to vmap over the lambdas that we could take a closer look at?

Suppose you have a loss function that you don’t know in closed form, but only on some kind of lattice (line, grid, etc.) of possible outputs. To train with this loss function, you have to interpolate it. Each interpolation region has a different parametrisation of the loss function, so you could store these per-region loss functions in a list and, for every output, do a binary search to find the corresponding region and hence which loss function to use. Every output is independent, so this could all be done in parallel, but it requires a binary search, which I couldn’t see how to express as a mathematical operation in tensor form.
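For a single 1-D lattice, the binary search itself can be batched: torch.searchsorted performs a vectorised binary search over a sorted tensor, and the per-region linear pieces can then be gathered by index. The sketch below assumes each region’s loss is linear (slope * x + intercept) and that the lattice points are stored in a sorted tensor; knots, slopes, intercepts and piecewise_linear_loss are hypothetical names with illustrative values. Gradients flow through outputs, while the region index is treated as a constant.

```python
import torch

def piecewise_linear_loss(outputs, knots, slopes, intercepts):
    """Hypothetical helper: evaluate a piecewise-linear loss element-wise.

    outputs:    tensor of any shape
    knots:      sorted 1-D tensor of lattice points, shape (K,)
    slopes:     per-region slopes, shape (K - 1,)
    intercepts: per-region intercepts, shape (K - 1,)
    """
    # Batched binary search: idx is the region with knots[idx] <= x < knots[idx + 1].
    idx = torch.searchsorted(knots, outputs.detach(), right=True) - 1
    # Outputs outside the lattice reuse the first/last region (linear extrapolation).
    idx = idx.clamp(0, slopes.numel() - 1)
    # Gather each element's coefficients; this is differentiable w.r.t. outputs.
    return slopes[idx] * outputs + intercepts[idx]

# Illustrative continuous loss with slopes 1, 2, 3 on [0, 1), [1, 2), [2, 3].
knots = torch.tensor([0.0, 1.0, 2.0, 3.0])
slopes = torch.tensor([1.0, 2.0, 3.0])
intercepts = torch.tensor([0.0, -1.0, -3.0])

outputs = torch.tensor([0.5, 1.5, 2.5], requires_grad=True)
values = piecewise_linear_loss(outputs, knots, slopes, intercepts)
values.sum().backward()
# values: 0.5, 2.0, 4.5; outputs.grad: 1.0, 2.0, 3.0 (the slope of each region)
```

This only covers lattices that can be described by sorted tensors plus per-region coefficients; an arbitrary Python lambda per region would still need a different approach.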

In my case it’s slightly more complicated, because not all examples use the same loss function. Each example has a label that identifies which loss to use, and then I need to interpolate that loss, and so on.
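If every loss is defined on a lattice with the same number of points (an assumption; lattices of different sizes would need padding), the per-example choice of loss can also be handled with indexing. Stack the lattices into an (L, K) tensor, select each example’s row with its label, and use the batched form of torch.searchsorted, which accepts a 2-D sorted sequence whose leading dimension matches the input’s. All names and values below are hypothetical.

```python
import torch

def per_example_loss(outputs, labels, knots, slopes, intercepts):
    """Hypothetical helper.

    outputs:    (B,) float tensor
    labels:     (B,) long tensor, which loss each example uses
    knots:      (L, K) lattice points, each row sorted
    slopes:     (L, K - 1) per-region slopes
    intercepts: (L, K - 1) per-region intercepts
    """
    row_knots = knots[labels]  # (B, K): each example's own lattice
    # Batched binary search, one row per example; the input must be (B, M), here M = 1.
    idx = torch.searchsorted(row_knots, outputs.detach().unsqueeze(-1), right=True)
    idx = (idx.squeeze(-1) - 1).clamp(0, slopes.shape[1] - 1)
    # Pick the (loss, region) coefficients per example; differentiable w.r.t. outputs.
    return slopes[labels, idx] * outputs + intercepts[labels, idx]

# Two illustrative losses, each on a 3-point lattice.
knots = torch.tensor([[0.0, 1.0, 2.0], [0.0, 2.0, 4.0]])
slopes = torch.tensor([[1.0, -1.0], [0.5, 2.0]])
intercepts = torch.tensor([[0.0, 2.0], [0.0, -3.0]])

outputs = torch.tensor([0.5, 1.5, 3.0], requires_grad=True)
labels = torch.tensor([0, 0, 1])
values = per_example_loss(outputs, labels, knots, slopes, intercepts)
values.sum().backward()
```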

I ended up doing this sequentially for every output: I computed the gradient of the interpolated loss for each output, saved those gradients, and finally set the loss to be outputs * gradients. This effectively solves the problem, but it is currently very slow because everything runs sequentially.
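The outputs * gradients trick works because the stored gradients are treated as constants: backpropagating the sum of outputs * g gives each output exactly its stored g. A minimal sketch, assuming the per-output gradients have already been computed (the values below are made up, and the sum is added so the loss is a scalar):

```python
import torch

outputs = torch.tensor([0.5, 1.5, 2.5], requires_grad=True)
# Hypothetical per-output gradients dL/d(output), e.g. computed one by one in a Python loop.
grads = torch.tensor([1.0, 2.0, 3.0])

# Detach so the gradients are constants: d(surrogate)/d(outputs) == grads.
surrogate = (outputs * grads.detach()).sum()
surrogate.backward()
print(outputs.grad)  # tensor([1., 2., 3.])
```

Calling outputs.backward(gradient=grads) achieves the same result without building the surrogate loss.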

This is very specific to what I’m doing at the moment, but it would be nice if there were a version of vmap that could apply an arbitrary function f(x) to every element of a tensor in parallel.