Apply a function to a tensor's columns

I’m trying to find an efficient way of applying a function along an axis of a tensor. The function could be as simple as sum(), returning a scalar for each slice.

For example, for an input of shape (32, 3, 1, 16) I would get an output of shape (32, 3, 1, 1).

Right now I’m looping over the tensor and applying the function to each slice. I suspect this is not the most efficient way :)
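A minimal sketch, assuming PyTorch and a random input of shape (32, 3, 1, 16), that puts the per-slice loop next to a single vectorized reduction over the last axis. `sum_last_axis_loop` is a hypothetical helper written only for this comparison; the loop runs once per position of the leading dimensions, while the vectorized call is one kernel over the whole tensor.

```python
import itertools

import torch

def sum_last_axis_loop(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: reduce the last axis one slice at a time (slow)."""
    out = torch.empty(*x.shape[:-1], 1, dtype=x.dtype)
    # Visit every index tuple of the leading dims, e.g. (i, j, k) for a 4-D input.
    for idx in itertools.product(*(range(n) for n in x.shape[:-1])):
        out[idx] = x[idx].sum()  # x[idx] is a 1-D slice of length x.shape[-1]
    return out

torch.manual_seed(0)
x = torch.randn(32, 3, 1, 16)

looped = sum_last_axis_loop(x)
vectorized = x.sum(dim=-1, keepdim=True)  # one call, no Python loop

print(looped.shape, vectorized.shape)
print(torch.allclose(looped, vectorized))
```

Both results have shape (32, 3, 1, 1); `allclose` rather than exact equality is used because floating-point summation order may differ between the two paths.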

Thank you

Do you mean something like this?

x.sum(dim=-1, keepdim=True)
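A short sketch, assuming PyTorch, of what `keepdim=True` changes: the reduced axis is kept with size 1 instead of being dropped, which is what produces the (32, 3, 1, 1) shape asked about. Other built-in reductions such as `mean` and `amax` accept the same `dim`/`keepdim` arguments.

```python
import torch

x = torch.randn(32, 3, 1, 16)

kept = x.sum(dim=-1, keepdim=True)     # reduced axis kept with size 1
dropped = x.sum(dim=-1)                # reduced axis removed
means = x.mean(dim=-1, keepdim=True)   # same pattern for other reductions
maxes = x.amax(dim=-1, keepdim=True)

print(kept.shape)     # torch.Size([32, 3, 1, 1])
print(dropped.shape)  # torch.Size([32, 3, 1])
```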

I need to be able to use my own function instead of sum(x).
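One possible approach, sketched under assumptions: if the custom function can be expressed with tensor operations that take a `dim` argument, it vectorizes directly. If it is more natural to write for a single 1-D vector, `torch.func.vmap` (available in PyTorch 2.0+) can map it over rows after the leading dimensions are flattened. Root-mean-square is used here purely as a stand-in for the custom function; `rms_vectorized`, `rms_1d`, and `apply_last_axis` are hypothetical names.

```python
import torch
from torch.func import vmap  # PyTorch 2.0+

def rms_vectorized(x: torch.Tensor) -> torch.Tensor:
    # Stand-in custom reduction written with dim-aware ops over the last axis.
    return x.pow(2).mean(dim=-1, keepdim=True).sqrt()

def rms_1d(v: torch.Tensor) -> torch.Tensor:
    # The same reduction written for one 1-D vector; returns a 0-dim tensor.
    return v.pow(2).mean().sqrt()

def apply_last_axis(fn, x: torch.Tensor) -> torch.Tensor:
    """Apply a 1-D -> scalar function to every vector along the last axis."""
    flat = x.reshape(-1, x.shape[-1])      # (N, L): one row per vector
    out = vmap(fn)(flat)                   # (N,): fn mapped over the rows
    return out.reshape(*x.shape[:-1], 1)   # restore leading dims, size-1 last axis

torch.manual_seed(0)
x = torch.randn(32, 3, 1, 16)

a = rms_vectorized(x)
b = apply_last_axis(rms_1d, x)
print(a.shape, b.shape)
```

`vmap` traces the function with batched tensors, so the mapped function should avoid data-dependent Python control flow on tensor values (for example `if v.sum() > 0:` or `.item()`); when that is needed, rewriting it with `torch.where` keeps it vectorizable.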