Hi Thomas!
I do not know of any functionality built into pytorch similar to your
apply_along_axis(). And even if there were, it would likely impose a
performance penalty, as it would still break up what might have been a
single, larger tensor operation into many smaller, axis-wise tensor
operations (absent some kind of hypothetical JIT compilation).
As a general rule, if you find yourself looping over a tensor, you should
see if you can recast your computation into pure (that is, loop-free)
tensor operations. Sometimes you can and sometimes you can’t.
Note that you can sometimes realize a net performance gain by getting
rid of loops, even if your loop-free approach is of higher computational
complexity.
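As a contrived sketch of that trade-off (the details here are made up purely
for illustration), compare computing prefix sums with a python loop, one
small tensor operation per element, against a single matrix multiply with a
lower-triangular mask, which does quadratic work but is just one large
tensor operation. Depending on sizes and hardware, the single matmul may or
may not come out ahead, but it avoids the per-iteration overhead entirely:
>>> import torch
>>> # contrived sketch: prefix sums with a python loop (one small tensor
>>> # operation per element) versus one lower-triangular matmul (quadratic
>>> # work, but a single large tensor operation)
>>> x = torch.randn (2000, dtype = torch.float64)
>>> def prefix_sum_loop (x):
...     out = torch.empty_like (x)
...     running = torch.zeros ((), dtype = x.dtype)
...     for i in range (x.shape[0]):
...         running = running + x[i]
...         out[i] = running
...     return out
...
>>> def prefix_sum_matmul (x):
...     n = x.shape[0]
...     return torch.tril (torch.ones (n, n, dtype = x.dtype)) @ x
...
>>> torch.allclose (prefix_sum_loop (x), prefix_sum_matmul (x))
True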
Here is a simplistic, contrived example that illustrates replacing
apply_along_axis() with a single pytorch tensor operation:
>>> import torch
>>> torch.__version__
'1.9.0'
>>> _ = torch.manual_seed (2021)
>>> def apply_along_axis(function, x, axis: int = 0):
...     return torch.stack([
...         function(x_i) for x_i in torch.unbind(x, dim=axis)
...     ], dim=axis)
...
>>> def my_fn (x):
...     return torch.softmax (x, 0)
...
>>> t = torch.randn (2, 3)
>>> t
tensor([[ 2.2871, 0.6413, -0.8615],
[-0.3649, -0.6931, 0.9023]])
>>> apply_along_axis (my_fn, t, 0)
tensor([[0.8092, 0.1561, 0.0347],
[0.1897, 0.1366, 0.6737]])
>>> torch.softmax (t, 1)
tensor([[0.8092, 0.1561, 0.0347],
[0.1897, 0.1366, 0.6737]])
>>> apply_along_axis (my_fn, t, 1)
tensor([[0.9341, 0.7916, 0.1463],
[0.0659, 0.2084, 0.8537]])
>>> torch.softmax (t, 0)
tensor([[0.9341, 0.7916, 0.1463],
[0.0659, 0.2084, 0.8537]])
You suggest that your use case involves having the function you apply
be an entire neural-network model.
Although various model layers do have constraints on the shapes they
expect, the basic building blocks accept (and sometimes require) an
arbitrary batch dimension. This suggests that reworking your model so
that you don’t need apply_along_axis could well be feasible.
Two building-block examples: Linear accepts an arbitrary number of
leading “batch” dimensions, so that’s likely to be easy. On the other
hand, Conv2d requires a tensor of exactly four dimensions, but its
leading dimension is an arbitrary batch dimension, so you can use
view() (or reshape()) to repackage multiple “batch” dimensions
into a single batch dimension. Thus:
>>> import torch
>>> torch.__version__
'1.9.0'
>>> torch.nn.Linear (2, 3) (torch.randn (7, 7, 5, 2)).shape
torch.Size([7, 7, 5, 3])
>>> torch.nn.Conv2d (2, 3, 3, padding = 1) (torch.randn (7, 7, 5, 2, 11, 11).view (7*7*5, 2, 11, 11)).view (7, 7, 5, 3, 11, 11).shape
torch.Size([7, 7, 5, 3, 11, 11])
Whether or not you would be able to push these kinds of techniques
through an entire model will depend on the model’s details, but there
are certainly some realistic models where you could.
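For instance, here is a hypothetical sketch (the layer sizes and shapes are
made up for illustration) of a small two-layer model that folds its extra
“batch” dimensions into the single batch dimension that Conv2d expects and
then restores them for the Linear layer:
>>> import torch
>>> class FoldedConvModel (torch.nn.Module):   # hypothetical example model
...     def __init__ (self):
...         super().__init__()
...         self.conv = torch.nn.Conv2d (2, 3, 3, padding = 1)
...         self.lin = torch.nn.Linear (3 * 11 * 11, 4)
...     def forward (self, x):                 # x has shape (d1, d2, 2, 11, 11)
...         d1, d2 = x.shape[:2]
...         y = self.conv (x.reshape (d1 * d2, 2, 11, 11))
...         y = y.reshape (d1, d2, -1)         # back to (d1, d2, 3 * 11 * 11)
...         return self.lin (y)                # Linear broadcasts over (d1, d2)
...
>>> FoldedConvModel() (torch.randn (7, 5, 2, 11, 11)).shape
torch.Size([7, 5, 4])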
Best.
K. Frank