A fast way to apply a function across an axis

Is there an efficient way to apply a function such as torch.inverse to a tensor of size (n, m, m), where the function is applied to each of the (m, m) matrices?

It seems that this does the job:

import torch

def apply(func, M):
    # Split M along dim 0, apply func to each (m, m) slice, and re-stack.
    tList = [func(m) for m in torch.unbind(M, dim=0)]
    res = torch.stack(tList, dim=0)
    return res

apply(torch.inverse, torch.randn(100, 200, 200))

but I am wondering if there is a more efficient approach.
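For reference, here is a quick correctness check of that helper (a sketch; the diagonally dominant test matrices are my own addition to keep the inverses well conditioned):

```python
import torch

def apply(func, M):
    # Apply func to each (m, m) slice along dim 0 and re-stack.
    return torch.stack([func(m) for m in torch.unbind(M, dim=0)], dim=0)

torch.manual_seed(0)
# Add 3 * I to each slice so every matrix is comfortably invertible.
M = torch.randn(4, 3, 3) + 3 * torch.eye(3)
inv = apply(torch.inverse, M)

# Each slice times its inverse should give the identity.
eye = torch.eye(3).expand(4, 3, 3)
assert torch.allclose(torch.bmm(M, inv), eye, atol=1e-4)
```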

TensorFlow's tf.matrix_inverse appears to handle batched input natively, as documented here: https://www.tensorflow.org/api_docs/python/tf/matrix_inverse
but I am not sure whether it uses a for loop internally or parallelizes across the batch.

It would be interesting to see a benchmark between the approach I see used in PyTorch and the one in TensorFlow. I will post the benchmark results later today.
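As a rough sanity check of loop vs. batched (not a PyTorch-vs-TensorFlow benchmark), numpy can stand in here, since np.linalg.inv accepts a stacked (n, m, m) array and inverts each slice:

```python
import time
import numpy as np

np.random.seed(0)
# Diagonally dominant slices so every matrix is safely invertible.
M = np.random.randn(100, 200, 200) + 3 * np.eye(200)

# Python loop over slices, mirroring the unbind/stack helper.
t0 = time.perf_counter()
looped = np.stack([np.linalg.inv(m) for m in M])
t_loop = time.perf_counter() - t0

# Batched call over the leading axis in one shot.
t0 = time.perf_counter()
batched = np.linalg.inv(M)
t_batch = time.perf_counter() - t0

print(f"loop: {t_loop:.4f}s  batched: {t_batch:.4f}s")
assert np.allclose(looped, batched)
```

Both paths produce identical results; the timing gap (if any) shows how much the per-slice Python loop costs on top of the underlying LAPACK calls.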


We don't support batch inverse right now; it looks like TF does. I wonder if they do anything more sophisticated than a for loop internally.

Do Tensor Comprehensions allow us to apply a function across an axis much faster?

I was also using unbind and stack as the equivalent of numpy's apply_along_axis. The biggest problem was that it roughly doubled the processing time. The only way around this is to somehow express the function as a matrix operation. Luckily, I was able to replace the axis operation with a series of matrix multiplications.
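As a toy illustration of that trade (my own example, not the poster's actual workload), a weighted sum along the last axis can be written either as a per-row loop or as a single matrix multiplication over the whole batch:

```python
import numpy as np

np.random.seed(0)
X = np.random.randn(1000, 64)
w = np.random.randn(64)

# Loop version: apply the function one row at a time.
looped = np.stack([row @ w for row in X])

# Matrix-operation version: one matmul covers the whole axis at once.
vectorized = X @ w

assert np.allclose(looped, vectorized)
```

The vectorized form does the same arithmetic but dispatches to BLAS once, instead of paying Python-loop overhead per row.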