PyTorch equivalent of numpy.apply_along_axis

I would like to apply a function to each row of a tensor. Is there a simple and efficient way to do this without indexing each row individually? I am looking for the PyTorch equivalent of numpy.apply_along_axis, if there is one.
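For reference, here is a minimal illustration of what numpy.apply_along_axis does. The array and the choice of np.sum are just an example. It passes each 1-D slice along the given axis to the function, so with a 2-D array, axis=1 means one call per row.

```python
import numpy as np

x = np.array([[1, 2, 3], [4, 5, 6]])

# apply_along_axis(func1d, axis, arr) calls func1d on every 1-D slice along `axis`.
# With axis=1 those slices are the rows of x: [1, 2, 3] and [4, 5, 6].
row_sums = np.apply_along_axis(np.sum, 1, x)  # array([ 6, 15])
```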


You could try (if you haven’t already):

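```python
import torch
# Placeholders: function, x, other_input, axis. Calls function(x_i, other_input[i])
# on each slice x_i of x along `axis`, then stacks the results along that dim.
torch.stack([
    function(x_i, other_input[i]) for i, x_i in enumerate(torch.unbind(x, dim=axis))
], dim=axis)
```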
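Below is a self-contained sketch of the same unbind/stack pattern wrapped in a helper. The name apply_along_dim and the row-scaling example are hypothetical, chosen only to make it runnable. Note that `dim` here is the dimension being iterated over (dim=0 yields the rows of a 2-D tensor). This is the opposite convention from apply_along_axis, where axis=1 gives the rows.

```python
import torch


def apply_along_dim(function, x, other_input, dim=0):
    """Hypothetical helper: call function(x_i, other_input[i]) on each slice
    x_i of x along `dim`, then stack the results back along the same dim."""
    return torch.stack(
        [function(x_i, other_input[i]) for i, x_i in enumerate(torch.unbind(x, dim=dim))],
        dim=dim,
    )


# Example: scale each row of x by its own factor.
x = torch.arange(6.0).reshape(2, 3)      # [[0., 1., 2.], [3., 4., 5.]]
factors = torch.tensor([10.0, 100.0])
out = apply_along_dim(lambda row, f: row * f, x, factors, dim=0)
# out == [[0., 10., 20.], [300., 400., 500.]]
```

This still runs a Python loop with one function call per slice. It is convenient, but it is not vectorized, so it can be slow for tensors with many rows.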

I would also like to hear about it if there’s a better way to do this!
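One possible alternative, assuming PyTorch 2.0 or later, is torch.vmap. It vectorizes a per-row function instead of looping in Python. This is only a sketch under that assumption. scaled_norm and the data are hypothetical, and vmap only works when the function is built from PyTorch ops without data-dependent Python control flow (for example, no `if row.sum() > 0:` branches and no `.item()` calls).

```python
import torch  # assumes PyTorch >= 2.0, where torch.vmap is available


def scaled_norm(row, factor):
    # Hypothetical per-row function: operates on one 1-D row and one scalar factor.
    return torch.linalg.vector_norm(row) * factor


x = torch.tensor([[3.0, 4.0], [6.0, 8.0]])
factors = torch.tensor([1.0, 2.0])

# in_dims=(0, 0): map over dim 0 of x and dim 0 of factors at the same time.
out = torch.vmap(scaled_norm, in_dims=(0, 0))(x, factors)

# The same result computed with a Python loop, for comparison.
loop_out = torch.stack([scaled_norm(r, f) for r, f in zip(x, factors)])
assert torch.allclose(out, loop_out)
```

If the function is already a built-in reduction such as a sum or a norm, passing the dimension to it directly (e.g. torch.linalg.vector_norm(x, dim=1)) is simpler still.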
