I’m trying to find an efficient way of applying a function along one axis of a tensor. The function could be as simple as sum(), returning a scalar for each slice along that axis.

For example, for a tensor of shape (32, 3, 1, 16), I would get a result of shape (32, 3, 1, 1).

Right now I’m looping through the tensor and applying the function to each slice. I guess this is not the most efficient way.
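
To make the question concrete, here is a minimal sketch of the loop next to a built-in reduction. The post does not name a library, so NumPy is an assumption here, and `x` and `reduce_with_loop` are hypothetical names used only for illustration. If the tensor is a PyTorch one, `torch.sum(x, dim=-1, keepdim=True)` plays the same role as the `keepdims=True` call below.

```python
import numpy as np

# Hypothetical input with the shape from the example: (32, 3, 1, 16).
x = np.arange(32 * 3 * 1 * 16, dtype=np.float64).reshape(32, 3, 1, 16)


def reduce_with_loop(x, fn):
    """Apply fn to each 1-D slice along the last axis, one slice at a time."""
    out = np.empty(x.shape[:-1] + (1,), dtype=x.dtype)
    # Visit every index combination of the leading axes.
    for idx in np.ndindex(*x.shape[:-1]):
        out[idx + (0,)] = fn(x[idx])  # x[idx] has shape (16,)
    return out


# Roughly the current approach: an explicit Python loop.
looped = reduce_with_loop(x, np.sum)

# Built-in reduction: keepdims=True keeps the reduced axis with length 1.
vectorized = x.sum(axis=-1, keepdims=True)

# Generic helper for an arbitrary 1-D function. It still loops in Python
# internally, so it is a convenience rather than a speed-up.
generic = np.apply_along_axis(np.sum, -1, x)[..., np.newaxis]

print(looped.shape, vectorized.shape, generic.shape)
```

For functions that have a built-in reduction (sum, mean, max and so on), passing the axis directly is usually the fastest route. For a custom function, rewriting it in terms of array operations that accept an axis argument tends to be faster than any per-slice loop.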

Thank you