I do not know of any functionality built into pytorch similar to your apply_along_axis(). And even if there were, it would still impose
a performance penalty, as it would still be breaking up what might
have been a single, larger tensor operation into many smaller, axis-wise
tensor operations (absent some kind of hypothetical JIT compilation).
As a general rule, if you find yourself looping over a tensor, you should
see if you can recast your computation into pure (that is, loop-free)
tensor operations. Sometimes you can and sometimes you can’t.
Note that you can sometimes realize a net performance gain by getting
rid of loops, even if your loop-free approach is of higher computational
complexity, because the loop-free version runs as a single optimized
(and potentially parallelized) tensor operation rather than as many
small operations driven by a python loop.
Here is a simplistic, contrived example that illustrates replacing apply_along_axis() with a single pytorch tensor operation:
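One possible such example (the specific per-row function, a sum of squares along dim 1, is my own illustrative choice):

```python
import numpy as np
import torch

# A contrived per-row computation: sum of squares along the last axis.
def f(row):
    return (row ** 2).sum()

a = np.arange(12.0).reshape(3, 4)
loop_result = np.apply_along_axis(f, 1, a)   # loops over rows in python

t = torch.arange(12.0).reshape(3, 4)
tensor_result = (t ** 2).sum(dim=1)          # one vectorized tensor op

print(loop_result)     # [ 14. 126. 366.]
print(tensor_result)   # tensor([ 14., 126., 366.])
```

The two results agree, but the second version performs the whole computation in one tensor operation instead of one python-level call per row.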
You suggest that your use case involves having the function you apply
be an entire neural-network model.
Although various model layers do have constraints on the shapes they
expect, the basic building blocks accept (and sometimes require) an
arbitrary batch dimension. This suggests that reworking your model so
that you don’t need apply_along_axis could be plausible.
Two building-block examples: Linear accepts an arbitrary number
of leading “batch” dimensions, so that’s likely to be easy. On the other
hand, Conv2d requires a tensor of exactly four dimensions, but its
leading dimension is an arbitrary batch dimension, so you can use
view() (or reshape()) to repackage multiple “batch” dimensions
into a single batch dimension. Thus:
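A minimal sketch of this fold-apply-unfold pattern (the specific channel counts, image size, and batch dimensions are illustrative):

```python
import torch

conv = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)

# Hypothetical input with two "batch" dimensions: (outer, inner, C, H, W).
x = torch.randn(5, 7, 3, 32, 32)

# Conv2d wants exactly (N, C, H, W), so fold the two batch dims into one ...
y = conv(x.reshape(5 * 7, 3, 32, 32))

# ... and unfold them again afterward (3x3 kernel, no padding: 32 -> 30).
y = y.reshape(5, 7, 8, 30, 30)
print(y.shape)   # torch.Size([5, 7, 8, 30, 30])
```

The single call to conv() processes all 35 "samples" in one batched operation, instead of looping over the outer dimension and calling the model 5 times.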