I would like to apply a function to each row of a tensor. Is there a simple and efficient way to do this without indexing each row manually? I am looking for the PyTorch equivalent of numpy.apply_along_axis, if one exists.
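For reference, here is a small sketch of the NumPy behaviour being asked about. The function `normalize` is a hypothetical per-row function chosen only for illustration. With `axis=1`, `np.apply_along_axis` passes each row `a[i, :]` to the function as a 1-D array and assembles the results into an output array.

```python
import numpy as np

a = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])

def normalize(row):
    # Hypothetical per-row function: scale the row so it sums to 1.
    return row / row.sum()

# axis=1: the function receives each row a[i, :] as a 1-D array.
out = np.apply_along_axis(normalize, 1, a)
print(out.shape)  # (2, 3)
```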
You could try (if you haven’t already):
torch.stack([function(x_i, other_input[i]) for i, x_i in enumerate(torch.unbind(x, dim=axis))], dim=axis)
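A self-contained sketch of the unbind/stack approach above, with the `other_input` argument dropped for simplicity. The helper name `apply_along_axis` and the function `normalize` are hypothetical. Note the convention difference: here `axis` is the dimension that gets iterated over (unbound), so rows of a 2-D tensor correspond to `axis=0`, whereas NumPy's `apply_along_axis` uses `axis=1` for rows because it names the dimension the function acts along.

```python
import torch

def apply_along_axis(function, x, axis=0):
    # Split x into slices along `axis`, apply `function` to each slice,
    # then stack the results back together along the same axis.
    return torch.stack([function(x_i) for x_i in torch.unbind(x, dim=axis)], dim=axis)

def normalize(row):
    # Hypothetical per-row function: scale the row so it sums to 1.
    return row / row.sum()

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# Rows of a 2-D tensor are the slices along dim 0.
out = apply_along_axis(normalize, x, axis=0)
print(out.shape)  # torch.Size([2, 3])
```

This still runs a Python loop over the slices, so it is convenient rather than fast; every call to `function` must also return tensors of the same shape for `torch.stack` to succeed.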
I would also like to hear if there’s a better way to do this!
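One alternative that may be worth trying, assuming PyTorch 2.0 or later, is `torch.vmap`, which maps a function written for a single row over a batch dimension without an explicit Python loop over slices. This is a different technique from the reply above, and `normalize` is again a hypothetical example function. The sketch checks that `vmap` gives the same result as the unbind/stack loop.

```python
import torch

def normalize(row):
    # Hypothetical per-row function, written for a single 1-D row.
    return row / row.sum()

x = torch.arange(1, 7, dtype=torch.float32).reshape(2, 3)

# torch.vmap maps `normalize` over dim 0 (the rows); in_dims/out_dims
# select which input dimension is mapped and where it appears in the output.
batched_normalize = torch.vmap(normalize, in_dims=0, out_dims=0)
out = batched_normalize(x)

# Reference result from looping over rows and stacking.
loop = torch.stack([normalize(r) for r in torch.unbind(x, dim=0)], dim=0)
print(torch.allclose(out, loop))  # True
```

`vmap` only works when the function is built from PyTorch operations without data-dependent Python control flow (for example, `if row.sum() > 0:` or `.item()` inside the function will fail). When the per-row function can be expressed directly with broadcasting, such as `x / x.sum(dim=1, keepdim=True)` here, that is usually the simplest and fastest option.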